diff --git a/mediapipe/tasks/python/test/test_utils.py b/mediapipe/tasks/python/test/test_utils.py index a9b961944..b2387eb2c 100644 --- a/mediapipe/tasks/python/test/test_utils.py +++ b/mediapipe/tasks/python/test/test_utils.py @@ -61,10 +61,13 @@ def create_calibration_file(file_dir: str, def assertProtoEqual(self, a, b, check_initialized=True, - normalize_numbers=False, msg=None): - """Fails with a useful error if a and b aren't equal. - Comparison of repeated fields matches the semantics of + normalize_numbers=True, msg=None): + """assertProtoEqual() is useful for unit tests. It produces much more helpful + output than assertEqual() for proto2 messages. Fails with a useful error if a + and b aren't equal. Comparison of repeated fields matches the semantics of unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter. + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L73 + Args: self: absltest.testing.parameterized.TestCase a: proto2 PB instance, or text string representing one. @@ -112,6 +115,8 @@ def _normalize_number_fields(pb): five digits of precision to account for python always storing them as 64-bit, and ensures doubles are floating point for when they're set to integers. Modifies pb in place. Recurses into nested objects. + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118 + Args: pb: proto2 message. Returns: diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 4ebc113bd..9b8c5e94e 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -50,11 +50,6 @@ _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -# TODO: Port assertProtoEquals -def _assert_proto_equals(self, expected, actual): # pylint: disable=unused-argument - test_utils.assertProtoEqual(self, expected, actual) - - def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: return _ClassificationResult(classifications=[ _Classifications( @@ -74,22 +69,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: categories=[ _Category( index=934, - score=0.7939587831497192, + score=0.793959, display_name='', category_name='cheeseburger'), _Category( index=932, - score=0.02739289402961731, + score=0.0273929, display_name='', category_name='bagel'), _Category( index=925, - score=0.01934075355529785, + score=0.0193408, display_name='', category_name='guacamole'), _Category( index=963, - score=0.006327860057353973, + score=0.00632786, display_name='', category_name='meat loaf') ], @@ -108,7 +103,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult: categories=[ _Category( index=806, - score=0.9965274930000305, + score=0.996527, display_name='', category_name='soccer ball') ], @@ -186,8 +181,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(self.test_image) # Comparing results. - _assert_proto_equals(self, image_result.to_pb2(), - expected_classification_result.to_pb2()) + test_utils.assertProtoEqual(self, image_result.to_pb2(), + expected_classification_result.to_pb2()) # Closes the classifier explicitly when the classifier is not used in # a context. classifier.close() @@ -214,8 +209,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(self.test_image) # Comparing results. - _assert_proto_equals(self, image_result.to_pb2(), - expected_classification_result.to_pb2()) + test_utils.assertProtoEqual(self, image_result.to_pb2(), + expected_classification_result.to_pb2()) def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) @@ -232,8 +227,8 @@ class ImageClassifierTest(parameterized.TestCase): # Performs image classification on the input. image_result = classifier.classify(test_image, roi) # Comparing results. - _assert_proto_equals(self, image_result.to_pb2(), - _generate_soccer_ball_results(0).to_pb2()) + test_utils.assertProtoEqual(self, image_result.to_pb2(), + _generate_soccer_ball_results(0).to_pb2()) def test_score_threshold_option(self): custom_classifier_options = _ClassifierOptions( @@ -401,8 +396,9 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - _assert_proto_equals(self, classification_result.to_pb2(), - _generate_burger_results(timestamp).to_pb2()) + test_utils.assertProtoEqual( + self, classification_result.to_pb2(), + _generate_burger_results(timestamp).to_pb2()) def test_classify_for_video_succeeds_with_region_of_interest(self): custom_classifier_options = _ClassifierOptions(max_results=1) @@ -420,8 +416,9 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( test_image, timestamp, roi) - _assert_proto_equals(self, classification_result.to_pb2(), - _generate_soccer_ball_results(timestamp).to_pb2()) + test_utils.assertProtoEqual( + self, classification_result.to_pb2(), + _generate_soccer_ball_results(timestamp).to_pb2()) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( @@ -463,8 +460,8 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): - _assert_proto_equals(self, result.to_pb2(), - expected_result_fn(timestamp_ms).to_pb2()) + test_utils.assertProtoEqual( + self, result.to_pb2(), expected_result_fn(timestamp_ms).to_pb2()) self.assertTrue( np.array_equal(output_image.numpy_view(), self.test_image.numpy_view())) @@ -493,8 +490,9 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): - _assert_proto_equals(self, result.to_pb2(), - _generate_soccer_ball_results(timestamp_ms).to_pb2()) + test_utils.assertProtoEqual( + self, result.to_pb2(), + _generate_soccer_ball_results(timestamp_ms).to_pb2()) self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms)