Added a test utility method to compare protos directly
This commit is contained in:
parent
ae5b09e2b2
commit
f8a98ccba4
|
@ -14,7 +14,12 @@
|
||||||
"""Test util for MediaPipe Tasks."""
|
"""Test util for MediaPipe Tasks."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import difflib
|
||||||
|
import six
|
||||||
|
|
||||||
|
from google.protobuf import descriptor
|
||||||
|
from google.protobuf import descriptor_pool
|
||||||
|
from google.protobuf import text_format
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
|
@ -53,3 +58,110 @@ def create_calibration_file(file_dir: str,
|
||||||
with open(calibration_file, mode="w") as file:
|
with open(calibration_file, mode="w") as file:
|
||||||
file.write(content)
|
file.write(content)
|
||||||
return calibration_file
|
return calibration_file
|
||||||
|
|
||||||
|
|
||||||
|
def assertProtoEqual(self, a, b, check_initialized=True,
|
||||||
|
normalize_numbers=False, msg=None):
|
||||||
|
"""Fails with a useful error if a and b aren't equal.
|
||||||
|
Comparison of repeated fields matches the semantics of
|
||||||
|
unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.
|
||||||
|
Args:
|
||||||
|
self: absltest.testing.parameterized.TestCase
|
||||||
|
a: proto2 PB instance, or text string representing one.
|
||||||
|
b: proto2 PB instance -- message.Message or subclass thereof.
|
||||||
|
check_initialized: boolean, whether to fail if either a or b isn't
|
||||||
|
initialized.
|
||||||
|
normalize_numbers: boolean, whether to normalize types and precision of
|
||||||
|
numbers before comparison.
|
||||||
|
msg: if specified, is used as the error message on failure.
|
||||||
|
"""
|
||||||
|
pool = descriptor_pool.Default()
|
||||||
|
if isinstance(a, six.string_types):
|
||||||
|
a = text_format.Merge(a, b.__class__(), descriptor_pool=pool)
|
||||||
|
|
||||||
|
for pb in a, b:
|
||||||
|
if check_initialized:
|
||||||
|
errors = pb.FindInitializationErrors()
|
||||||
|
if errors:
|
||||||
|
self.fail('Initialization errors: %s\n%s' % (errors, pb))
|
||||||
|
if normalize_numbers:
|
||||||
|
_normalize_number_fields(pb)
|
||||||
|
|
||||||
|
a_str = text_format.MessageToString(a, descriptor_pool=pool)
|
||||||
|
b_str = text_format.MessageToString(b, descriptor_pool=pool)
|
||||||
|
|
||||||
|
# Some Python versions would perform regular diff instead of multi-line
|
||||||
|
# diff if string is longer than 2**16. We substitute this behavior
|
||||||
|
# with a call to unified_diff instead to have easier-to-read diffs.
|
||||||
|
# For context, see: https://bugs.python.org/issue11763.
|
||||||
|
if len(a_str) < 2**16 and len(b_str) < 2**16:
|
||||||
|
self.assertMultiLineEqual(a_str, b_str, msg=msg)
|
||||||
|
else:
|
||||||
|
diff = ''.join(
|
||||||
|
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
|
||||||
|
if diff:
|
||||||
|
self.fail('%s :\n%s' % (msg, diff))
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_number_fields(pb):
|
||||||
|
"""Normalizes types and precisions of number fields in a protocol buffer.
|
||||||
|
Due to subtleties in the python protocol buffer implementation, it is possible
|
||||||
|
for values to have different types and precision depending on whether they
|
||||||
|
were set and retrieved directly or deserialized from a protobuf. This function
|
||||||
|
normalizes integer values to ints and longs based on width, 32-bit floats to
|
||||||
|
five digits of precision to account for python always storing them as 64-bit,
|
||||||
|
and ensures doubles are floating point for when they're set to integers.
|
||||||
|
Modifies pb in place. Recurses into nested objects.
|
||||||
|
Args:
|
||||||
|
pb: proto2 message.
|
||||||
|
Returns:
|
||||||
|
the given pb, modified in place.
|
||||||
|
"""
|
||||||
|
for desc, values in pb.ListFields():
|
||||||
|
is_repeated = True
|
||||||
|
if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED:
|
||||||
|
is_repeated = False
|
||||||
|
values = [values]
|
||||||
|
|
||||||
|
normalized_values = None
|
||||||
|
|
||||||
|
# We force 32-bit values to int and 64-bit values to long to make
|
||||||
|
# alternate implementations where the distinction is more significant
|
||||||
|
# (e.g. the C++ implementation) simpler.
|
||||||
|
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
|
||||||
|
descriptor.FieldDescriptor.TYPE_UINT64,
|
||||||
|
descriptor.FieldDescriptor.TYPE_SINT64):
|
||||||
|
normalized_values = [int(x) for x in values]
|
||||||
|
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
|
||||||
|
descriptor.FieldDescriptor.TYPE_UINT32,
|
||||||
|
descriptor.FieldDescriptor.TYPE_SINT32,
|
||||||
|
descriptor.FieldDescriptor.TYPE_ENUM):
|
||||||
|
normalized_values = [int(x) for x in values]
|
||||||
|
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
|
||||||
|
normalized_values = [round(x, 6) for x in values]
|
||||||
|
elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
|
||||||
|
normalized_values = [round(float(x), 7) for x in values]
|
||||||
|
|
||||||
|
if normalized_values is not None:
|
||||||
|
if is_repeated:
|
||||||
|
pb.ClearField(desc.name)
|
||||||
|
getattr(pb, desc.name).extend(normalized_values)
|
||||||
|
else:
|
||||||
|
setattr(pb, desc.name, normalized_values[0])
|
||||||
|
|
||||||
|
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
|
||||||
|
desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
|
||||||
|
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
|
||||||
|
desc.message_type.has_options and
|
||||||
|
desc.message_type.GetOptions().map_entry):
|
||||||
|
# This is a map, only recurse if the values have a message type.
|
||||||
|
if (desc.message_type.fields_by_number[2].type ==
|
||||||
|
descriptor.FieldDescriptor.TYPE_MESSAGE):
|
||||||
|
for v in six.itervalues(values):
|
||||||
|
_normalize_number_fields(v)
|
||||||
|
else:
|
||||||
|
for v in values:
|
||||||
|
# recursive step
|
||||||
|
_normalize_number_fields(v)
|
||||||
|
|
||||||
|
return pb
|
||||||
|
|
|
@ -51,8 +51,8 @@ _MAX_RESULTS = 3
|
||||||
|
|
||||||
|
|
||||||
# TODO: Port assertProtoEquals
|
# TODO: Port assertProtoEquals
|
||||||
def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument
|
def _assert_proto_equals(self, expected, actual): # pylint: disable=unused-argument
|
||||||
pass
|
test_utils.assertProtoEqual(self, expected, actual)
|
||||||
|
|
||||||
|
|
||||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
|
@ -186,7 +186,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
_assert_proto_equals(image_result.to_pb2(),
|
_assert_proto_equals(self, image_result.to_pb2(),
|
||||||
expected_classification_result.to_pb2())
|
expected_classification_result.to_pb2())
|
||||||
# Closes the classifier explicitly when the classifier is not used in
|
# Closes the classifier explicitly when the classifier is not used in
|
||||||
# a context.
|
# a context.
|
||||||
|
@ -214,7 +214,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
_assert_proto_equals(image_result.to_pb2(),
|
_assert_proto_equals(self, image_result.to_pb2(),
|
||||||
expected_classification_result.to_pb2())
|
expected_classification_result.to_pb2())
|
||||||
|
|
||||||
def test_classify_succeeds_with_region_of_interest(self):
|
def test_classify_succeeds_with_region_of_interest(self):
|
||||||
|
@ -232,7 +232,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(test_image, roi)
|
image_result = classifier.classify(test_image, roi)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
_assert_proto_equals(image_result.to_pb2(),
|
_assert_proto_equals(self, image_result.to_pb2(),
|
||||||
_generate_soccer_ball_results(0).to_pb2())
|
_generate_soccer_ball_results(0).to_pb2())
|
||||||
|
|
||||||
def test_score_threshold_option(self):
|
def test_score_threshold_option(self):
|
||||||
|
@ -401,7 +401,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
classification_result = classifier.classify_for_video(
|
classification_result = classifier.classify_for_video(
|
||||||
self.test_image, timestamp)
|
self.test_image, timestamp)
|
||||||
_assert_proto_equals(classification_result.to_pb2(),
|
_assert_proto_equals(self, classification_result.to_pb2(),
|
||||||
_generate_burger_results(timestamp).to_pb2())
|
_generate_burger_results(timestamp).to_pb2())
|
||||||
|
|
||||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||||
|
@ -420,8 +420,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
classification_result = classifier.classify_for_video(
|
classification_result = classifier.classify_for_video(
|
||||||
test_image, timestamp, roi)
|
test_image, timestamp, roi)
|
||||||
self.assertEqual(classification_result,
|
_assert_proto_equals(self, classification_result.to_pb2(),
|
||||||
_generate_soccer_ball_results(timestamp))
|
_generate_soccer_ball_results(timestamp).to_pb2())
|
||||||
|
|
||||||
def test_calling_classify_in_live_stream_mode(self):
|
def test_calling_classify_in_live_stream_mode(self):
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
|
@ -463,7 +463,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
||||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
_assert_proto_equals(result.to_pb2(),
|
_assert_proto_equals(self, result.to_pb2(),
|
||||||
expected_result_fn(timestamp_ms).to_pb2())
|
expected_result_fn(timestamp_ms).to_pb2())
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.array_equal(output_image.numpy_view(),
|
np.array_equal(output_image.numpy_view(),
|
||||||
|
@ -493,7 +493,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
||||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
_assert_proto_equals(result.to_pb2(),
|
_assert_proto_equals(self, result.to_pb2(),
|
||||||
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
||||||
self.assertEqual(output_image.width, test_image.width)
|
self.assertEqual(output_image.width, test_image.width)
|
||||||
self.assertEqual(output_image.height, test_image.height)
|
self.assertEqual(output_image.height, test_image.height)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user