Updated docstrings and removed the redundant private helper method

This commit is contained in:
kinaryml 2022-10-26 23:56:54 -07:00
parent f8a98ccba4
commit 8194513934
2 changed files with 30 additions and 27 deletions

View File

@ -61,10 +61,13 @@ def create_calibration_file(file_dir: str,
def assertProtoEqual(self, a, b, check_initialized=True, def assertProtoEqual(self, a, b, check_initialized=True,
normalize_numbers=False, msg=None): normalize_numbers=True, msg=None):
"""Fails with a useful error if a and b aren't equal. """assertProtoEqual() is useful for unit tests. It produces much more helpful
Comparison of repeated fields matches the semantics of output than assertEqual() for proto2 messages. Fails with a useful error if a
and b aren't equal. Comparison of repeated fields matches the semantics of
unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter. unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L73
Args: Args:
self: absltest.testing.parameterized.TestCase self: absltest.testing.parameterized.TestCase
a: proto2 PB instance, or text string representing one. a: proto2 PB instance, or text string representing one.
@ -112,6 +115,8 @@ def _normalize_number_fields(pb):
five digits of precision to account for python always storing them as 64-bit, five digits of precision to account for python always storing them as 64-bit,
and ensures doubles are floating point for when they're set to integers. and ensures doubles are floating point for when they're set to integers.
Modifies pb in place. Recurses into nested objects. Modifies pb in place. Recurses into nested objects.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118
Args: Args:
pb: proto2 message. pb: proto2 message.
Returns: Returns:

View File

@ -50,11 +50,6 @@ _SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3 _MAX_RESULTS = 3
# TODO: Port assertProtoEquals
def _assert_proto_equals(self, expected, actual): # pylint: disable=unused-argument
test_utils.assertProtoEqual(self, expected, actual)
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[ return _ClassificationResult(classifications=[
_Classifications( _Classifications(
@ -74,22 +69,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
categories=[ categories=[
_Category( _Category(
index=934, index=934,
score=0.7939587831497192, score=0.793959,
display_name='', display_name='',
category_name='cheeseburger'), category_name='cheeseburger'),
_Category( _Category(
index=932, index=932,
score=0.02739289402961731, score=0.0273929,
display_name='', display_name='',
category_name='bagel'), category_name='bagel'),
_Category( _Category(
index=925, index=925,
score=0.01934075355529785, score=0.0193408,
display_name='', display_name='',
category_name='guacamole'), category_name='guacamole'),
_Category( _Category(
index=963, index=963,
score=0.006327860057353973, score=0.00632786,
display_name='', display_name='',
category_name='meat loaf') category_name='meat loaf')
], ],
@ -108,7 +103,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
categories=[ categories=[
_Category( _Category(
index=806, index=806,
score=0.9965274930000305, score=0.996527,
display_name='', display_name='',
category_name='soccer ball') category_name='soccer ball')
], ],
@ -186,8 +181,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
# Comparing results. # Comparing results.
_assert_proto_equals(self, image_result.to_pb2(), test_utils.assertProtoEqual(self, image_result.to_pb2(),
expected_classification_result.to_pb2()) expected_classification_result.to_pb2())
# Closes the classifier explicitly when the classifier is not used in # Closes the classifier explicitly when the classifier is not used in
# a context. # a context.
classifier.close() classifier.close()
@ -214,8 +209,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
# Comparing results. # Comparing results.
_assert_proto_equals(self, image_result.to_pb2(), test_utils.assertProtoEqual(self, image_result.to_pb2(),
expected_classification_result.to_pb2()) expected_classification_result.to_pb2())
def test_classify_succeeds_with_region_of_interest(self): def test_classify_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
@ -232,8 +227,8 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(test_image, roi) image_result = classifier.classify(test_image, roi)
# Comparing results. # Comparing results.
_assert_proto_equals(self, image_result.to_pb2(), test_utils.assertProtoEqual(self, image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2()) _generate_soccer_ball_results(0).to_pb2())
def test_score_threshold_option(self): def test_score_threshold_option(self):
custom_classifier_options = _ClassifierOptions( custom_classifier_options = _ClassifierOptions(
@ -401,8 +396,9 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
self.test_image, timestamp) self.test_image, timestamp)
_assert_proto_equals(self, classification_result.to_pb2(), test_utils.assertProtoEqual(
_generate_burger_results(timestamp).to_pb2()) self, classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self): def test_classify_for_video_succeeds_with_region_of_interest(self):
custom_classifier_options = _ClassifierOptions(max_results=1) custom_classifier_options = _ClassifierOptions(max_results=1)
@ -420,8 +416,9 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
test_image, timestamp, roi) test_image, timestamp, roi)
_assert_proto_equals(self, classification_result.to_pb2(), test_utils.assertProtoEqual(
_generate_soccer_ball_results(timestamp).to_pb2()) self, classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp).to_pb2())
def test_calling_classify_in_live_stream_mode(self): def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
@ -463,8 +460,8 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
_assert_proto_equals(self, result.to_pb2(), test_utils.assertProtoEqual(
expected_result_fn(timestamp_ms).to_pb2()) self, result.to_pb2(), expected_result_fn(timestamp_ms).to_pb2())
self.assertTrue( self.assertTrue(
np.array_equal(output_image.numpy_view(), np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view())) self.test_image.numpy_view()))
@ -493,8 +490,9 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
_assert_proto_equals(self, result.to_pb2(), test_utils.assertProtoEqual(
_generate_soccer_ball_results(timestamp_ms).to_pb2()) self, result.to_pb2(),
_generate_soccer_ball_results(timestamp_ms).to_pb2())
self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height) self.assertEqual(output_image.height, test_image.height)
self.assertLess(observed_timestamp_ms, timestamp_ms) self.assertLess(observed_timestamp_ms, timestamp_ms)