diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index 92c5f4038..5ad057983 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -27,5 +27,8 @@ py_library( "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__", "//mediapipe/tasks:internal", ], - deps = ["//mediapipe/python:_framework_bindings"], + deps = [ + "//mediapipe/python:_framework_bindings", + "@com_google_protobuf//:protobuf_python", + ], ) diff --git a/mediapipe/tasks/python/test/test_utils.py b/mediapipe/tasks/python/test/test_utils.py index b428f8302..d2e76c57b 100644 --- a/mediapipe/tasks/python/test/test_utils.py +++ b/mediapipe/tasks/python/test/test_utils.py @@ -13,9 +13,15 @@ # limitations under the License. """Test util for MediaPipe Tasks.""" +import difflib import os from absl import flags +import six + +from google.protobuf import descriptor +from google.protobuf import descriptor_pool +from google.protobuf import text_format from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image_frame as image_frame_module @@ -53,3 +59,126 @@ def create_calibration_file(file_dir: str, with open(calibration_file, mode="w") as file: file.write(content) return calibration_file + + +def assert_proto_equals(self, + a, + b, + check_initialized=True, + normalize_numbers=True, + msg=None): + """assert_proto_equals() is useful for unit tests. + + It produces much more helpful output than assertEqual() for proto2 messages. + Fails with a useful error if a and b aren't equal. Comparison of repeated + fields matches the semantics of unittest.TestCase.assertEqual(), ie order and + extra duplicates fields matter. + + This is a fork of https://github.com/tensorflow/tensorflow/blob/ + master/tensorflow/python/util/protobuf/compare.py#L73. We use slightly + different rounding cutoffs to support Mac usage. + + Args: + self: absltest.testing.parameterized.TestCase + a: proto2 PB instance, or text string representing one. + b: proto2 PB instance -- message.Message or subclass thereof. + check_initialized: boolean, whether to fail if either a or b isn't + initialized. + normalize_numbers: boolean, whether to normalize types and precision of + numbers before comparison. + msg: if specified, is used as the error message on failure. + """ + pool = descriptor_pool.Default() + if isinstance(a, six.string_types): + a = text_format.Parse(a, b.__class__(), descriptor_pool=pool) + + for pb in a, b: + if check_initialized: + errors = pb.FindInitializationErrors() + if errors: + self.fail("Initialization errors: %s\n%s" % (errors, pb)) + if normalize_numbers: + _normalize_number_fields(pb) + + a_str = text_format.MessageToString(a, descriptor_pool=pool) + b_str = text_format.MessageToString(b, descriptor_pool=pool) + + # Some Python versions would perform regular diff instead of multi-line + # diff if string is longer than 2**16. We substitute this behavior + # with a call to unified_diff instead to have easier-to-read diffs. + # For context, see: https://bugs.python.org/issue11763. + if len(a_str) < 2**16 and len(b_str) < 2**16: + self.assertMultiLineEqual(a_str, b_str, msg=msg) + else: + diff = "".join( + difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))) + if diff: + self.fail("%s :\n%s" % (msg, diff)) + + +def _normalize_number_fields(pb): + """Normalizes types and precisions of number fields in a protocol buffer. + + Due to subtleties in the python protocol buffer implementation, it is possible + for values to have different types and precision depending on whether they + were set and retrieved directly or deserialized from a protobuf. This function + normalizes integer values to ints and longs based on width, 32-bit floats to + five digits of precision to account for python always storing them as 64-bit, + and ensures doubles are floating point for when they're set to integers. + Modifies pb in place. Recurses into nested objects. https://github.com/tensorf + low/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118 + + Args: + pb: proto2 message. + + Returns: + the given pb, modified in place. + """ + for desc, values in pb.ListFields(): + is_repeated = True + if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED: + is_repeated = False + values = [values] + + normalized_values = None + + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + if desc.type in (descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_SINT64): + normalized_values = [int(x) for x in values] + elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_ENUM): + normalized_values = [int(x) for x in values] + elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT: + normalized_values = [round(x, 5) for x in values] + elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE: + normalized_values = [round(float(x), 6) for x in values] + + if normalized_values is not None: + if is_repeated: + pb.ClearField(desc.name) + getattr(pb, desc.name).extend(normalized_values) + else: + setattr(pb, desc.name, normalized_values[0]) + + if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or + desc.type == descriptor.FieldDescriptor.TYPE_GROUP): + if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + desc.message_type.has_options and + desc.message_type.GetOptions().map_entry): + # This is a map, only recurse if the values have a message type. + if (desc.message_type.fields_by_number[2].type == + descriptor.FieldDescriptor.TYPE_MESSAGE): + for v in six.itervalues(values): + _normalize_number_fields(v) + else: + for v in values: + # recursive step + _normalize_number_fields(v) + + return pb diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index afaf921a7..274bf3434 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -50,11 +50,6 @@ _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -# TODO: Port assertProtoEquals -def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument - pass - - def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: return _ClassificationResult(classifications=[ _Classifications( @@ -74,22 +69,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: categories=[ _Category( index=934, - score=0.7939587831497192, + score=0.793959, display_name='', category_name='cheeseburger'), _Category( index=932, - score=0.02739289402961731, + score=0.0273929, display_name='', category_name='bagel'), _Category( index=925, - score=0.01934075355529785, + score=0.0193408, display_name='', category_name='guacamole'), _Category( index=963, - score=0.006327860057353973, + score=0.00632786, display_name='', category_name='meat loaf') ], @@ -108,7 +103,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult: categories=[ _Category( index=806, - score=0.9965274930000305, + score=0.996527, display_name='', category_name='soccer ball') ], @@ -186,8 +181,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(self.test_image) # Comparing results. - _assert_proto_equals(image_result.to_pb2(), - expected_classification_result.to_pb2()) + test_utils.assert_proto_equals(self, image_result.to_pb2(), + expected_classification_result.to_pb2()) # Closes the classifier explicitly when the classifier is not used in # a context. classifier.close() @@ -214,8 +209,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(self.test_image) # Comparing results. - _assert_proto_equals(image_result.to_pb2(), - expected_classification_result.to_pb2()) + test_utils.assert_proto_equals(self, image_result.to_pb2(), + expected_classification_result.to_pb2()) def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) @@ -232,8 +227,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(test_image, roi) # Comparing results. - _assert_proto_equals(image_result.to_pb2(), - _generate_soccer_ball_results(0).to_pb2()) + test_utils.assert_proto_equals(self, image_result.to_pb2(), + _generate_soccer_ball_results(0).to_pb2()) def test_score_threshold_option(self): custom_classifier_options = _ClassifierOptions( @@ -401,8 +396,9 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - _assert_proto_equals(classification_result.to_pb2(), - _generate_burger_results(timestamp).to_pb2()) + test_utils.assert_proto_equals( + self, classification_result.to_pb2(), + _generate_burger_results(timestamp).to_pb2()) def test_classify_for_video_succeeds_with_region_of_interest(self): custom_classifier_options = _ClassifierOptions(max_results=1) @@ -420,8 +416,9 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( test_image, timestamp, roi) - self.assertEqual(classification_result, - _generate_soccer_ball_results(timestamp)) + test_utils.assert_proto_equals( + self, classification_result.to_pb2(), + _generate_soccer_ball_results(timestamp).to_pb2()) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( @@ -463,8 +460,8 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): - _assert_proto_equals(result.to_pb2(), - expected_result_fn(timestamp_ms).to_pb2()) + test_utils.assert_proto_equals(self, result.to_pb2(), + expected_result_fn(timestamp_ms).to_pb2()) self.assertTrue( np.array_equal(output_image.numpy_view(), self.test_image.numpy_view())) @@ -493,8 +490,9 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): - _assert_proto_equals(result.to_pb2(), - _generate_soccer_ball_results(timestamp_ms).to_pb2()) + test_utils.assert_proto_equals( + self, result.to_pb2(), + _generate_soccer_ball_results(timestamp_ms).to_pb2()) self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms)