Added a test utility method to compare protos directly

This commit is contained in:
kinaryml 2022-10-25 23:38:32 -07:00
parent ae5b09e2b2
commit f8a98ccba4
2 changed files with 122 additions and 10 deletions

View File

@ -14,7 +14,12 @@
"""Test util for MediaPipe Tasks.""" """Test util for MediaPipe Tasks."""
import os import os
import difflib
import six
from google.protobuf import descriptor
from google.protobuf import descriptor_pool
from google.protobuf import text_format
from absl import flags from absl import flags
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
@ -53,3 +58,110 @@ def create_calibration_file(file_dir: str,
with open(calibration_file, mode="w") as file: with open(calibration_file, mode="w") as file:
file.write(content) file.write(content)
return calibration_file return calibration_file
def assertProtoEqual(self, a, b, check_initialized=True,
normalize_numbers=False, msg=None):
"""Fails with a useful error if a and b aren't equal.
Comparison of repeated fields matches the semantics of
unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.
Args:
self: absltest.testing.parameterized.TestCase
a: proto2 PB instance, or text string representing one.
b: proto2 PB instance -- message.Message or subclass thereof.
check_initialized: boolean, whether to fail if either a or b isn't
initialized.
normalize_numbers: boolean, whether to normalize types and precision of
numbers before comparison.
msg: if specified, is used as the error message on failure.
"""
pool = descriptor_pool.Default()
if isinstance(a, six.string_types):
a = text_format.Merge(a, b.__class__(), descriptor_pool=pool)
for pb in a, b:
if check_initialized:
errors = pb.FindInitializationErrors()
if errors:
self.fail('Initialization errors: %s\n%s' % (errors, pb))
if normalize_numbers:
_normalize_number_fields(pb)
a_str = text_format.MessageToString(a, descriptor_pool=pool)
b_str = text_format.MessageToString(b, descriptor_pool=pool)
# Some Python versions would perform regular diff instead of multi-line
# diff if string is longer than 2**16. We substitute this behavior
# with a call to unified_diff instead to have easier-to-read diffs.
# For context, see: https://bugs.python.org/issue11763.
if len(a_str) < 2**16 and len(b_str) < 2**16:
self.assertMultiLineEqual(a_str, b_str, msg=msg)
else:
diff = ''.join(
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
if diff:
self.fail('%s :\n%s' % (msg, diff))
def _normalize_number_fields(pb):
"""Normalizes types and precisions of number fields in a protocol buffer.
Due to subtleties in the python protocol buffer implementation, it is possible
for values to have different types and precision depending on whether they
were set and retrieved directly or deserialized from a protobuf. This function
normalizes integer values to ints and longs based on width, 32-bit floats to
five digits of precision to account for python always storing them as 64-bit,
and ensures doubles are floating point for when they're set to integers.
Modifies pb in place. Recurses into nested objects.
Args:
pb: proto2 message.
Returns:
the given pb, modified in place.
"""
for desc, values in pb.ListFields():
is_repeated = True
if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED:
is_repeated = False
values = [values]
normalized_values = None
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
descriptor.FieldDescriptor.TYPE_UINT64,
descriptor.FieldDescriptor.TYPE_SINT64):
normalized_values = [int(x) for x in values]
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
descriptor.FieldDescriptor.TYPE_UINT32,
descriptor.FieldDescriptor.TYPE_SINT32,
descriptor.FieldDescriptor.TYPE_ENUM):
normalized_values = [int(x) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
normalized_values = [round(x, 6) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
normalized_values = [round(float(x), 7) for x in values]
if normalized_values is not None:
if is_repeated:
pb.ClearField(desc.name)
getattr(pb, desc.name).extend(normalized_values)
else:
setattr(pb, desc.name, normalized_values[0])
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
desc.message_type.has_options and
desc.message_type.GetOptions().map_entry):
# This is a map, only recurse if the values have a message type.
if (desc.message_type.fields_by_number[2].type ==
descriptor.FieldDescriptor.TYPE_MESSAGE):
for v in six.itervalues(values):
_normalize_number_fields(v)
else:
for v in values:
# recursive step
_normalize_number_fields(v)
return pb

View File

@ -51,8 +51,8 @@ _MAX_RESULTS = 3
# TODO: Port assertProtoEquals # TODO: Port assertProtoEquals
def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument def _assert_proto_equals(self, expected, actual): # pylint: disable=unused-argument
pass test_utils.assertProtoEqual(self, expected, actual)
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
@ -186,7 +186,7 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
# Comparing results. # Comparing results.
_assert_proto_equals(image_result.to_pb2(), _assert_proto_equals(self, image_result.to_pb2(),
expected_classification_result.to_pb2()) expected_classification_result.to_pb2())
# Closes the classifier explicitly when the classifier is not used in # Closes the classifier explicitly when the classifier is not used in
# a context. # a context.
@ -214,7 +214,7 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
# Comparing results. # Comparing results.
_assert_proto_equals(image_result.to_pb2(), _assert_proto_equals(self, image_result.to_pb2(),
expected_classification_result.to_pb2()) expected_classification_result.to_pb2())
def test_classify_succeeds_with_region_of_interest(self): def test_classify_succeeds_with_region_of_interest(self):
@ -232,7 +232,7 @@ class ImageClassifierTest(parameterized.TestCase):
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(test_image, roi) image_result = classifier.classify(test_image, roi)
# Comparing results. # Comparing results.
_assert_proto_equals(image_result.to_pb2(), _assert_proto_equals(self, image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2()) _generate_soccer_ball_results(0).to_pb2())
def test_score_threshold_option(self): def test_score_threshold_option(self):
@ -401,7 +401,7 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
self.test_image, timestamp) self.test_image, timestamp)
_assert_proto_equals(classification_result.to_pb2(), _assert_proto_equals(self, classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2()) _generate_burger_results(timestamp).to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self): def test_classify_for_video_succeeds_with_region_of_interest(self):
@ -420,8 +420,8 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
test_image, timestamp, roi) test_image, timestamp, roi)
self.assertEqual(classification_result, _assert_proto_equals(self, classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp)) _generate_soccer_ball_results(timestamp).to_pb2())
def test_calling_classify_in_live_stream_mode(self): def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
@ -463,7 +463,7 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
_assert_proto_equals(result.to_pb2(), _assert_proto_equals(self, result.to_pb2(),
expected_result_fn(timestamp_ms).to_pb2()) expected_result_fn(timestamp_ms).to_pb2())
self.assertTrue( self.assertTrue(
np.array_equal(output_image.numpy_view(), np.array_equal(output_image.numpy_view(),
@ -493,7 +493,7 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
_assert_proto_equals(result.to_pb2(), _assert_proto_equals(self, result.to_pb2(),
_generate_soccer_ball_results(timestamp_ms).to_pb2()) _generate_soccer_ball_results(timestamp_ms).to_pb2())
self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height) self.assertEqual(output_image.height, test_image.height)