Updated docstrings and removed the redundant private helper method
This commit is contained in:
parent
f8a98ccba4
commit
8194513934
|
@ -61,10 +61,13 @@ def create_calibration_file(file_dir: str,
|
||||||
|
|
||||||
|
|
||||||
def assertProtoEqual(self, a, b, check_initialized=True,
|
def assertProtoEqual(self, a, b, check_initialized=True,
|
||||||
normalize_numbers=False, msg=None):
|
normalize_numbers=True, msg=None):
|
||||||
"""Fails with a useful error if a and b aren't equal.
|
"""assertProtoEqual() is useful for unit tests. It produces much more helpful
|
||||||
Comparison of repeated fields matches the semantics of
|
output than assertEqual() for proto2 messages. Fails with a useful error if a
|
||||||
|
and b aren't equal. Comparison of repeated fields matches the semantics of
|
||||||
unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.
|
unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.
|
||||||
|
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L73
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
self: absltest.testing.parameterized.TestCase
|
self: absltest.testing.parameterized.TestCase
|
||||||
a: proto2 PB instance, or text string representing one.
|
a: proto2 PB instance, or text string representing one.
|
||||||
|
@ -112,6 +115,8 @@ def _normalize_number_fields(pb):
|
||||||
five digits of precision to account for python always storing them as 64-bit,
|
five digits of precision to account for python always storing them as 64-bit,
|
||||||
and ensures doubles are floating point for when they're set to integers.
|
and ensures doubles are floating point for when they're set to integers.
|
||||||
Modifies pb in place. Recurses into nested objects.
|
Modifies pb in place. Recurses into nested objects.
|
||||||
|
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pb: proto2 message.
|
pb: proto2 message.
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -50,11 +50,6 @@ _SCORE_THRESHOLD = 0.5
|
||||||
_MAX_RESULTS = 3
|
_MAX_RESULTS = 3
|
||||||
|
|
||||||
|
|
||||||
# TODO: Port assertProtoEquals
|
|
||||||
def _assert_proto_equals(self, expected, actual): # pylint: disable=unused-argument
|
|
||||||
test_utils.assertProtoEqual(self, expected, actual)
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
return _ClassificationResult(classifications=[
|
return _ClassificationResult(classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
|
@ -74,22 +69,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=934,
|
index=934,
|
||||||
score=0.7939587831497192,
|
score=0.793959,
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='cheeseburger'),
|
category_name='cheeseburger'),
|
||||||
_Category(
|
_Category(
|
||||||
index=932,
|
index=932,
|
||||||
score=0.02739289402961731,
|
score=0.0273929,
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='bagel'),
|
category_name='bagel'),
|
||||||
_Category(
|
_Category(
|
||||||
index=925,
|
index=925,
|
||||||
score=0.01934075355529785,
|
score=0.0193408,
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='guacamole'),
|
category_name='guacamole'),
|
||||||
_Category(
|
_Category(
|
||||||
index=963,
|
index=963,
|
||||||
score=0.006327860057353973,
|
score=0.00632786,
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='meat loaf')
|
category_name='meat loaf')
|
||||||
],
|
],
|
||||||
|
@ -108,7 +103,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=806,
|
index=806,
|
||||||
score=0.9965274930000305,
|
score=0.996527,
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='soccer ball')
|
category_name='soccer ball')
|
||||||
],
|
],
|
||||||
|
@ -186,7 +181,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
_assert_proto_equals(self, image_result.to_pb2(),
|
test_utils.assertProtoEqual(self, image_result.to_pb2(),
|
||||||
expected_classification_result.to_pb2())
|
expected_classification_result.to_pb2())
|
||||||
# Closes the classifier explicitly when the classifier is not used in
|
# Closes the classifier explicitly when the classifier is not used in
|
||||||
# a context.
|
# a context.
|
||||||
|
@ -214,7 +209,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
_assert_proto_equals(self, image_result.to_pb2(),
|
test_utils.assertProtoEqual(self, image_result.to_pb2(),
|
||||||
expected_classification_result.to_pb2())
|
expected_classification_result.to_pb2())
|
||||||
|
|
||||||
def test_classify_succeeds_with_region_of_interest(self):
|
def test_classify_succeeds_with_region_of_interest(self):
|
||||||
|
@ -232,7 +227,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(test_image, roi)
|
image_result = classifier.classify(test_image, roi)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
_assert_proto_equals(self, image_result.to_pb2(),
|
test_utils.assertProtoEqual(self, image_result.to_pb2(),
|
||||||
_generate_soccer_ball_results(0).to_pb2())
|
_generate_soccer_ball_results(0).to_pb2())
|
||||||
|
|
||||||
def test_score_threshold_option(self):
|
def test_score_threshold_option(self):
|
||||||
|
@ -401,7 +396,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
classification_result = classifier.classify_for_video(
|
classification_result = classifier.classify_for_video(
|
||||||
self.test_image, timestamp)
|
self.test_image, timestamp)
|
||||||
_assert_proto_equals(self, classification_result.to_pb2(),
|
test_utils.assertProtoEqual(
|
||||||
|
self, classification_result.to_pb2(),
|
||||||
_generate_burger_results(timestamp).to_pb2())
|
_generate_burger_results(timestamp).to_pb2())
|
||||||
|
|
||||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||||
|
@ -420,7 +416,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
classification_result = classifier.classify_for_video(
|
classification_result = classifier.classify_for_video(
|
||||||
test_image, timestamp, roi)
|
test_image, timestamp, roi)
|
||||||
_assert_proto_equals(self, classification_result.to_pb2(),
|
test_utils.assertProtoEqual(
|
||||||
|
self, classification_result.to_pb2(),
|
||||||
_generate_soccer_ball_results(timestamp).to_pb2())
|
_generate_soccer_ball_results(timestamp).to_pb2())
|
||||||
|
|
||||||
def test_calling_classify_in_live_stream_mode(self):
|
def test_calling_classify_in_live_stream_mode(self):
|
||||||
|
@ -463,8 +460,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
||||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
_assert_proto_equals(self, result.to_pb2(),
|
test_utils.assertProtoEqual(
|
||||||
expected_result_fn(timestamp_ms).to_pb2())
|
self, result.to_pb2(), expected_result_fn(timestamp_ms).to_pb2())
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.array_equal(output_image.numpy_view(),
|
np.array_equal(output_image.numpy_view(),
|
||||||
self.test_image.numpy_view()))
|
self.test_image.numpy_view()))
|
||||||
|
@ -493,7 +490,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
||||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
_assert_proto_equals(self, result.to_pb2(),
|
test_utils.assertProtoEqual(
|
||||||
|
self, result.to_pb2(),
|
||||||
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
||||||
self.assertEqual(output_image.width, test_image.width)
|
self.assertEqual(output_image.width, test_image.width)
|
||||||
self.assertEqual(output_image.height, test_image.height)
|
self.assertEqual(output_image.height, test_image.height)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user