Merge pull request #3800 from kinaryml:python-test-proto-equals

PiperOrigin-RevId: 485340924
This commit is contained in:
Copybara-Service 2022-11-01 09:42:19 -07:00
commit 6e0397b226
3 changed files with 155 additions and 25 deletions

View File

@ -27,5 +27,8 @@ py_library(
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__", "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
"//mediapipe/tasks:internal", "//mediapipe/tasks:internal",
], ],
deps = ["//mediapipe/python:_framework_bindings"], deps = [
"//mediapipe/python:_framework_bindings",
"@com_google_protobuf//:protobuf_python",
],
) )

View File

@ -13,9 +13,15 @@
# limitations under the License. # limitations under the License.
"""Test util for MediaPipe Tasks.""" """Test util for MediaPipe Tasks."""
import difflib
import os import os
from absl import flags from absl import flags
import six
from google.protobuf import descriptor
from google.protobuf import descriptor_pool
from google.protobuf import text_format
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module from mediapipe.python._framework_bindings import image_frame as image_frame_module
@ -53,3 +59,126 @@ def create_calibration_file(file_dir: str,
with open(calibration_file, mode="w") as file: with open(calibration_file, mode="w") as file:
file.write(content) file.write(content)
return calibration_file return calibration_file
def assert_proto_equals(self,
a,
b,
check_initialized=True,
normalize_numbers=True,
msg=None):
"""assert_proto_equals() is useful for unit tests.
It produces much more helpful output than assertEqual() for proto2 messages.
Fails with a useful error if a and b aren't equal. Comparison of repeated
fields matches the semantics of unittest.TestCase.assertEqual(), ie order and
extra duplicates fields matter.
This is a fork of https://github.com/tensorflow/tensorflow/blob/
master/tensorflow/python/util/protobuf/compare.py#L73. We use slightly
different rounding cutoffs to support Mac usage.
Args:
self: absltest.testing.parameterized.TestCase
a: proto2 PB instance, or text string representing one.
b: proto2 PB instance -- message.Message or subclass thereof.
check_initialized: boolean, whether to fail if either a or b isn't
initialized.
normalize_numbers: boolean, whether to normalize types and precision of
numbers before comparison.
msg: if specified, is used as the error message on failure.
"""
pool = descriptor_pool.Default()
if isinstance(a, six.string_types):
a = text_format.Parse(a, b.__class__(), descriptor_pool=pool)
for pb in a, b:
if check_initialized:
errors = pb.FindInitializationErrors()
if errors:
self.fail("Initialization errors: %s\n%s" % (errors, pb))
if normalize_numbers:
_normalize_number_fields(pb)
a_str = text_format.MessageToString(a, descriptor_pool=pool)
b_str = text_format.MessageToString(b, descriptor_pool=pool)
# Some Python versions would perform regular diff instead of multi-line
# diff if string is longer than 2**16. We substitute this behavior
# with a call to unified_diff instead to have easier-to-read diffs.
# For context, see: https://bugs.python.org/issue11763.
if len(a_str) < 2**16 and len(b_str) < 2**16:
self.assertMultiLineEqual(a_str, b_str, msg=msg)
else:
diff = "".join(
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
if diff:
self.fail("%s :\n%s" % (msg, diff))
def _normalize_number_fields(pb):
"""Normalizes types and precisions of number fields in a protocol buffer.
Due to subtleties in the python protocol buffer implementation, it is possible
for values to have different types and precision depending on whether they
were set and retrieved directly or deserialized from a protobuf. This function
normalizes integer values to ints and longs based on width, 32-bit floats to
five digits of precision to account for python always storing them as 64-bit,
and ensures doubles are floating point for when they're set to integers.
Modifies pb in place. Recurses into nested objects. https://github.com/tensorf
low/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118
Args:
pb: proto2 message.
Returns:
the given pb, modified in place.
"""
for desc, values in pb.ListFields():
is_repeated = True
if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED:
is_repeated = False
values = [values]
normalized_values = None
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
descriptor.FieldDescriptor.TYPE_UINT64,
descriptor.FieldDescriptor.TYPE_SINT64):
normalized_values = [int(x) for x in values]
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
descriptor.FieldDescriptor.TYPE_UINT32,
descriptor.FieldDescriptor.TYPE_SINT32,
descriptor.FieldDescriptor.TYPE_ENUM):
normalized_values = [int(x) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
normalized_values = [round(x, 5) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
normalized_values = [round(float(x), 6) for x in values]
if normalized_values is not None:
if is_repeated:
pb.ClearField(desc.name)
getattr(pb, desc.name).extend(normalized_values)
else:
setattr(pb, desc.name, normalized_values[0])
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
desc.message_type.has_options and
desc.message_type.GetOptions().map_entry):
# This is a map, only recurse if the values have a message type.
if (desc.message_type.fields_by_number[2].type ==
descriptor.FieldDescriptor.TYPE_MESSAGE):
for v in six.itervalues(values):
_normalize_number_fields(v)
else:
for v in values:
# recursive step
_normalize_number_fields(v)
return pb

View File

@ -50,11 +50,6 @@ _SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3 _MAX_RESULTS = 3
# TODO: Port assertProtoEquals
def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument
pass
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[ return _ClassificationResult(classifications=[
_Classifications( _Classifications(
@ -74,22 +69,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
categories=[ categories=[
_Category( _Category(
index=934, index=934,
score=0.7939587831497192, score=0.793959,
display_name='', display_name='',
category_name='cheeseburger'), category_name='cheeseburger'),
_Category( _Category(
index=932, index=932,
score=0.02739289402961731, score=0.0273929,
display_name='', display_name='',
category_name='bagel'), category_name='bagel'),
_Category( _Category(
index=925, index=925,
score=0.01934075355529785, score=0.0193408,
display_name='', display_name='',
category_name='guacamole'), category_name='guacamole'),
_Category( _Category(
index=963, index=963,
score=0.006327860057353973, score=0.00632786,
display_name='', display_name='',
category_name='meat loaf') category_name='meat loaf')
], ],
@ -108,7 +103,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
categories=[ categories=[
_Category( _Category(
index=806, index=806,
score=0.9965274930000305, score=0.996527,
display_name='', display_name='',
category_name='soccer ball') category_name='soccer ball')
], ],
@ -186,7 +181,7 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
# Comparing results. # Comparing results.
_assert_proto_equals(image_result.to_pb2(), test_utils.assert_proto_equals(self, image_result.to_pb2(),
expected_classification_result.to_pb2()) expected_classification_result.to_pb2())
# Closes the classifier explicitly when the classifier is not used in # Closes the classifier explicitly when the classifier is not used in
# a context. # a context.
@ -214,7 +209,7 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
# Comparing results. # Comparing results.
_assert_proto_equals(image_result.to_pb2(), test_utils.assert_proto_equals(self, image_result.to_pb2(),
expected_classification_result.to_pb2()) expected_classification_result.to_pb2())
def test_classify_succeeds_with_region_of_interest(self): def test_classify_succeeds_with_region_of_interest(self):
@ -232,7 +227,7 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(test_image, roi) image_result = classifier.classify(test_image, roi)
# Comparing results. # Comparing results.
_assert_proto_equals(image_result.to_pb2(), test_utils.assert_proto_equals(self, image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2()) _generate_soccer_ball_results(0).to_pb2())
def test_score_threshold_option(self): def test_score_threshold_option(self):
@ -401,7 +396,8 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
self.test_image, timestamp) self.test_image, timestamp)
_assert_proto_equals(classification_result.to_pb2(), test_utils.assert_proto_equals(
self, classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2()) _generate_burger_results(timestamp).to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self): def test_classify_for_video_succeeds_with_region_of_interest(self):
@ -420,8 +416,9 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
test_image, timestamp, roi) test_image, timestamp, roi)
self.assertEqual(classification_result, test_utils.assert_proto_equals(
_generate_soccer_ball_results(timestamp)) self, classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp).to_pb2())
def test_calling_classify_in_live_stream_mode(self): def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
@ -463,7 +460,7 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
_assert_proto_equals(result.to_pb2(), test_utils.assert_proto_equals(self, result.to_pb2(),
expected_result_fn(timestamp_ms).to_pb2()) expected_result_fn(timestamp_ms).to_pb2())
self.assertTrue( self.assertTrue(
np.array_equal(output_image.numpy_view(), np.array_equal(output_image.numpy_view(),
@ -493,7 +490,8 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
_assert_proto_equals(result.to_pb2(), test_utils.assert_proto_equals(
self, result.to_pb2(),
_generate_soccer_ball_results(timestamp_ms).to_pb2()) _generate_soccer_ball_results(timestamp_ms).to_pb2())
self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height) self.assertEqual(output_image.height, test_image.height)