Updated docstrings and removed the redundant private helper method

This commit is contained in:
kinaryml 2022-10-26 23:56:54 -07:00
parent f8a98ccba4
commit 8194513934
2 changed files with 30 additions and 27 deletions

View File

@ -61,10 +61,13 @@ def create_calibration_file(file_dir: str,
def assertProtoEqual(self, a, b, check_initialized=True,
normalize_numbers=False, msg=None):
"""Fails with a useful error if a and b aren't equal.
Comparison of repeated fields matches the semantics of
normalize_numbers=True, msg=None):
"""assertProtoEqual() is useful for unit tests. It produces much more helpful
output than assertEqual() for proto2 messages. Fails with a useful error if a
and b aren't equal. Comparison of repeated fields matches the semantics of
unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L73
Args:
self: absltest.testing.parameterized.TestCase
a: proto2 PB instance, or text string representing one.
@ -112,6 +115,8 @@ def _normalize_number_fields(pb):
five digits of precision to account for python always storing them as 64-bit,
and ensures doubles are floating point for when they're set to integers.
Modifies pb in place. Recurses into nested objects.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118
Args:
pb: proto2 message.
Returns:

View File

@ -50,11 +50,6 @@ _SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3
# TODO: Port assertProtoEquals
def _assert_proto_equals(self, expected, actual): # pylint: disable=unused-argument
test_utils.assertProtoEqual(self, expected, actual)
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
@ -74,22 +69,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
categories=[
_Category(
index=934,
score=0.7939587831497192,
score=0.793959,
display_name='',
category_name='cheeseburger'),
_Category(
index=932,
score=0.02739289402961731,
score=0.0273929,
display_name='',
category_name='bagel'),
_Category(
index=925,
score=0.01934075355529785,
score=0.0193408,
display_name='',
category_name='guacamole'),
_Category(
index=963,
score=0.006327860057353973,
score=0.00632786,
display_name='',
category_name='meat loaf')
],
@ -108,7 +103,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
categories=[
_Category(
index=806,
score=0.9965274930000305,
score=0.996527,
display_name='',
category_name='soccer ball')
],
@ -186,8 +181,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
_assert_proto_equals(self, image_result.to_pb2(),
expected_classification_result.to_pb2())
test_utils.assertProtoEqual(self, image_result.to_pb2(),
expected_classification_result.to_pb2())
# Closes the classifier explicitly when the classifier is not used in
# a context.
classifier.close()
@ -214,8 +209,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
_assert_proto_equals(self, image_result.to_pb2(),
expected_classification_result.to_pb2())
test_utils.assertProtoEqual(self, image_result.to_pb2(),
expected_classification_result.to_pb2())
def test_classify_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
@ -232,8 +227,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input.
image_result = classifier.classify(test_image, roi)
# Comparing results.
_assert_proto_equals(self, image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2())
test_utils.assertProtoEqual(self, image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2())
def test_score_threshold_option(self):
custom_classifier_options = _ClassifierOptions(
@ -401,8 +396,9 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
self.test_image, timestamp)
_assert_proto_equals(self, classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2())
test_utils.assertProtoEqual(
self, classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self):
custom_classifier_options = _ClassifierOptions(max_results=1)
@ -420,8 +416,9 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
test_image, timestamp, roi)
_assert_proto_equals(self, classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp).to_pb2())
test_utils.assertProtoEqual(
self, classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp).to_pb2())
def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions(
@ -463,8 +460,8 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int):
_assert_proto_equals(self, result.to_pb2(),
expected_result_fn(timestamp_ms).to_pb2())
test_utils.assertProtoEqual(
self, result.to_pb2(), expected_result_fn(timestamp_ms).to_pb2())
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
@ -493,8 +490,9 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int):
_assert_proto_equals(self, result.to_pb2(),
_generate_soccer_ball_results(timestamp_ms).to_pb2())
test_utils.assertProtoEqual(
self, result.to_pb2(),
_generate_soccer_ball_results(timestamp_ms).to_pb2())
self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height)
self.assertLess(observed_timestamp_ms, timestamp_ms)