Updated docstrings and removed the redundant private helper method
This commit is contained in:
parent
f8a98ccba4
commit
8194513934
|
@ -61,10 +61,13 @@ def create_calibration_file(file_dir: str,
|
|||
|
||||
|
||||
def assertProtoEqual(self, a, b, check_initialized=True,
|
||||
normalize_numbers=False, msg=None):
|
||||
"""Fails with a useful error if a and b aren't equal.
|
||||
Comparison of repeated fields matches the semantics of
|
||||
normalize_numbers=True, msg=None):
|
||||
"""assertProtoEqual() is useful for unit tests. It produces much more helpful
|
||||
output than assertEqual() for proto2 messages. Fails with a useful error if a
|
||||
and b aren't equal. Comparison of repeated fields matches the semantics of
|
||||
unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L73
|
||||
|
||||
Args:
|
||||
self: absltest.testing.parameterized.TestCase
|
||||
a: proto2 PB instance, or text string representing one.
|
||||
|
@ -112,6 +115,8 @@ def _normalize_number_fields(pb):
|
|||
five digits of precision to account for python always storing them as 64-bit,
|
||||
and ensures doubles are floating point for when they're set to integers.
|
||||
Modifies pb in place. Recurses into nested objects.
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118
|
||||
|
||||
Args:
|
||||
pb: proto2 message.
|
||||
Returns:
|
||||
|
|
|
@ -50,11 +50,6 @@ _SCORE_THRESHOLD = 0.5
|
|||
_MAX_RESULTS = 3
|
||||
|
||||
|
||||
# TODO: Port assertProtoEquals
|
||||
def _assert_proto_equals(self, expected, actual): # pylint: disable=unused-argument
|
||||
test_utils.assertProtoEqual(self, expected, actual)
|
||||
|
||||
|
||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||
return _ClassificationResult(classifications=[
|
||||
_Classifications(
|
||||
|
@ -74,22 +69,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
|||
categories=[
|
||||
_Category(
|
||||
index=934,
|
||||
score=0.7939587831497192,
|
||||
score=0.793959,
|
||||
display_name='',
|
||||
category_name='cheeseburger'),
|
||||
_Category(
|
||||
index=932,
|
||||
score=0.02739289402961731,
|
||||
score=0.0273929,
|
||||
display_name='',
|
||||
category_name='bagel'),
|
||||
_Category(
|
||||
index=925,
|
||||
score=0.01934075355529785,
|
||||
score=0.0193408,
|
||||
display_name='',
|
||||
category_name='guacamole'),
|
||||
_Category(
|
||||
index=963,
|
||||
score=0.006327860057353973,
|
||||
score=0.00632786,
|
||||
display_name='',
|
||||
category_name='meat loaf')
|
||||
],
|
||||
|
@ -108,7 +103,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
|
|||
categories=[
|
||||
_Category(
|
||||
index=806,
|
||||
score=0.9965274930000305,
|
||||
score=0.996527,
|
||||
display_name='',
|
||||
category_name='soccer ball')
|
||||
],
|
||||
|
@ -186,8 +181,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
# Comparing results.
|
||||
_assert_proto_equals(self, image_result.to_pb2(),
|
||||
expected_classification_result.to_pb2())
|
||||
test_utils.assertProtoEqual(self, image_result.to_pb2(),
|
||||
expected_classification_result.to_pb2())
|
||||
# Closes the classifier explicitly when the classifier is not used in
|
||||
# a context.
|
||||
classifier.close()
|
||||
|
@ -214,8 +209,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
# Comparing results.
|
||||
_assert_proto_equals(self, image_result.to_pb2(),
|
||||
expected_classification_result.to_pb2())
|
||||
test_utils.assertProtoEqual(self, image_result.to_pb2(),
|
||||
expected_classification_result.to_pb2())
|
||||
|
||||
def test_classify_succeeds_with_region_of_interest(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
|
@ -232,8 +227,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(test_image, roi)
|
||||
# Comparing results.
|
||||
_assert_proto_equals(self, image_result.to_pb2(),
|
||||
_generate_soccer_ball_results(0).to_pb2())
|
||||
test_utils.assertProtoEqual(self, image_result.to_pb2(),
|
||||
_generate_soccer_ball_results(0).to_pb2())
|
||||
|
||||
def test_score_threshold_option(self):
|
||||
custom_classifier_options = _ClassifierOptions(
|
||||
|
@ -401,8 +396,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
self.test_image, timestamp)
|
||||
_assert_proto_equals(self, classification_result.to_pb2(),
|
||||
_generate_burger_results(timestamp).to_pb2())
|
||||
test_utils.assertProtoEqual(
|
||||
self, classification_result.to_pb2(),
|
||||
_generate_burger_results(timestamp).to_pb2())
|
||||
|
||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||
|
@ -420,8 +416,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
test_image, timestamp, roi)
|
||||
_assert_proto_equals(self, classification_result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp).to_pb2())
|
||||
test_utils.assertProtoEqual(
|
||||
self, classification_result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp).to_pb2())
|
||||
|
||||
def test_calling_classify_in_live_stream_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
|
@ -463,8 +460,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
|
||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
_assert_proto_equals(self, result.to_pb2(),
|
||||
expected_result_fn(timestamp_ms).to_pb2())
|
||||
test_utils.assertProtoEqual(
|
||||
self, result.to_pb2(), expected_result_fn(timestamp_ms).to_pb2())
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(),
|
||||
self.test_image.numpy_view()))
|
||||
|
@ -493,8 +490,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
|
||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
_assert_proto_equals(self, result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
||||
test_utils.assertProtoEqual(
|
||||
self, result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
||||
self.assertEqual(output_image.width, test_image.width)
|
||||
self.assertEqual(output_image.height, test_image.height)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
|
|
Loading…
Reference in New Issue
Block a user