Merge pull request #3800 from kinaryml:python-test-proto-equals
PiperOrigin-RevId: 485340924
This commit is contained in:
		
						commit
						6e0397b226
					
				| 
						 | 
					@ -27,5 +27,8 @@ py_library(
 | 
				
			||||||
        "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
 | 
					        "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
 | 
				
			||||||
        "//mediapipe/tasks:internal",
 | 
					        "//mediapipe/tasks:internal",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    deps = ["//mediapipe/python:_framework_bindings"],
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/python:_framework_bindings",
 | 
				
			||||||
 | 
					        "@com_google_protobuf//:protobuf_python",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,9 +13,15 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
"""Test util for MediaPipe Tasks."""
 | 
					"""Test util for MediaPipe Tasks."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import difflib
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from absl import flags
 | 
					from absl import flags
 | 
				
			||||||
 | 
					import six
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from google.protobuf import descriptor
 | 
				
			||||||
 | 
					from google.protobuf import descriptor_pool
 | 
				
			||||||
 | 
					from google.protobuf import text_format
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mediapipe.python._framework_bindings import image as image_module
 | 
					from mediapipe.python._framework_bindings import image as image_module
 | 
				
			||||||
from mediapipe.python._framework_bindings import image_frame as image_frame_module
 | 
					from mediapipe.python._framework_bindings import image_frame as image_frame_module
 | 
				
			||||||
| 
						 | 
					@ -53,3 +59,126 @@ def create_calibration_file(file_dir: str,
 | 
				
			||||||
  with open(calibration_file, mode="w") as file:
 | 
					  with open(calibration_file, mode="w") as file:
 | 
				
			||||||
    file.write(content)
 | 
					    file.write(content)
 | 
				
			||||||
  return calibration_file
 | 
					  return calibration_file
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def assert_proto_equals(self,
 | 
				
			||||||
 | 
					                        a,
 | 
				
			||||||
 | 
					                        b,
 | 
				
			||||||
 | 
					                        check_initialized=True,
 | 
				
			||||||
 | 
					                        normalize_numbers=True,
 | 
				
			||||||
 | 
					                        msg=None):
 | 
				
			||||||
 | 
					  """assert_proto_equals() is useful for unit tests.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  It produces much more helpful output than assertEqual() for proto2 messages.
 | 
				
			||||||
 | 
					  Fails with a useful error if a and b aren't equal. Comparison of repeated
 | 
				
			||||||
 | 
					  fields matches the semantics of unittest.TestCase.assertEqual(), ie order and
 | 
				
			||||||
 | 
					  extra duplicates fields matter.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  This is a fork of https://github.com/tensorflow/tensorflow/blob/
 | 
				
			||||||
 | 
					  master/tensorflow/python/util/protobuf/compare.py#L73. We use slightly
 | 
				
			||||||
 | 
					  different rounding cutoffs to support Mac usage.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Args:
 | 
				
			||||||
 | 
					    self: absltest.testing.parameterized.TestCase
 | 
				
			||||||
 | 
					    a: proto2 PB instance, or text string representing one.
 | 
				
			||||||
 | 
					    b: proto2 PB instance -- message.Message or subclass thereof.
 | 
				
			||||||
 | 
					    check_initialized: boolean, whether to fail if either a or b isn't
 | 
				
			||||||
 | 
					      initialized.
 | 
				
			||||||
 | 
					    normalize_numbers: boolean, whether to normalize types and precision of
 | 
				
			||||||
 | 
					      numbers before comparison.
 | 
				
			||||||
 | 
					    msg: if specified, is used as the error message on failure.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  pool = descriptor_pool.Default()
 | 
				
			||||||
 | 
					  if isinstance(a, six.string_types):
 | 
				
			||||||
 | 
					    a = text_format.Parse(a, b.__class__(), descriptor_pool=pool)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for pb in a, b:
 | 
				
			||||||
 | 
					    if check_initialized:
 | 
				
			||||||
 | 
					      errors = pb.FindInitializationErrors()
 | 
				
			||||||
 | 
					      if errors:
 | 
				
			||||||
 | 
					        self.fail("Initialization errors: %s\n%s" % (errors, pb))
 | 
				
			||||||
 | 
					    if normalize_numbers:
 | 
				
			||||||
 | 
					      _normalize_number_fields(pb)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  a_str = text_format.MessageToString(a, descriptor_pool=pool)
 | 
				
			||||||
 | 
					  b_str = text_format.MessageToString(b, descriptor_pool=pool)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Some Python versions would perform regular diff instead of multi-line
 | 
				
			||||||
 | 
					  # diff if string is longer than 2**16. We substitute this behavior
 | 
				
			||||||
 | 
					  # with a call to unified_diff instead to have easier-to-read diffs.
 | 
				
			||||||
 | 
					  # For context, see: https://bugs.python.org/issue11763.
 | 
				
			||||||
 | 
					  if len(a_str) < 2**16 and len(b_str) < 2**16:
 | 
				
			||||||
 | 
					    self.assertMultiLineEqual(a_str, b_str, msg=msg)
 | 
				
			||||||
 | 
					  else:
 | 
				
			||||||
 | 
					    diff = "".join(
 | 
				
			||||||
 | 
					        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
 | 
				
			||||||
 | 
					    if diff:
 | 
				
			||||||
 | 
					      self.fail("%s :\n%s" % (msg, diff))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _normalize_number_fields(pb):
 | 
				
			||||||
 | 
					  """Normalizes types and precisions of number fields in a protocol buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Due to subtleties in the python protocol buffer implementation, it is possible
 | 
				
			||||||
 | 
					  for values to have different types and precision depending on whether they
 | 
				
			||||||
 | 
					  were set and retrieved directly or deserialized from a protobuf. This function
 | 
				
			||||||
 | 
					  normalizes integer values to ints and longs based on width, 32-bit floats to
 | 
				
			||||||
 | 
					  five digits of precision to account for python always storing them as 64-bit,
 | 
				
			||||||
 | 
					  and ensures doubles are floating point for when they're set to integers.
 | 
				
			||||||
 | 
					  Modifies pb in place. Recurses into nested objects. https://github.com/tensorf
 | 
				
			||||||
 | 
					  low/tensorflow/blob/master/tensorflow/python/util/protobuf/compare.py#L118
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Args:
 | 
				
			||||||
 | 
					    pb: proto2 message.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Returns:
 | 
				
			||||||
 | 
					    the given pb, modified in place.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  for desc, values in pb.ListFields():
 | 
				
			||||||
 | 
					    is_repeated = True
 | 
				
			||||||
 | 
					    if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED:
 | 
				
			||||||
 | 
					      is_repeated = False
 | 
				
			||||||
 | 
					      values = [values]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    normalized_values = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # We force 32-bit values to int and 64-bit values to long to make
 | 
				
			||||||
 | 
					    # alternate implementations where the distinction is more significant
 | 
				
			||||||
 | 
					    # (e.g. the C++ implementation) simpler.
 | 
				
			||||||
 | 
					    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
 | 
				
			||||||
 | 
					                     descriptor.FieldDescriptor.TYPE_UINT64,
 | 
				
			||||||
 | 
					                     descriptor.FieldDescriptor.TYPE_SINT64):
 | 
				
			||||||
 | 
					      normalized_values = [int(x) for x in values]
 | 
				
			||||||
 | 
					    elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
 | 
				
			||||||
 | 
					                       descriptor.FieldDescriptor.TYPE_UINT32,
 | 
				
			||||||
 | 
					                       descriptor.FieldDescriptor.TYPE_SINT32,
 | 
				
			||||||
 | 
					                       descriptor.FieldDescriptor.TYPE_ENUM):
 | 
				
			||||||
 | 
					      normalized_values = [int(x) for x in values]
 | 
				
			||||||
 | 
					    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
 | 
				
			||||||
 | 
					      normalized_values = [round(x, 5) for x in values]
 | 
				
			||||||
 | 
					    elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
 | 
				
			||||||
 | 
					      normalized_values = [round(float(x), 6) for x in values]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if normalized_values is not None:
 | 
				
			||||||
 | 
					      if is_repeated:
 | 
				
			||||||
 | 
					        pb.ClearField(desc.name)
 | 
				
			||||||
 | 
					        getattr(pb, desc.name).extend(normalized_values)
 | 
				
			||||||
 | 
					      else:
 | 
				
			||||||
 | 
					        setattr(pb, desc.name, normalized_values[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
 | 
				
			||||||
 | 
					        desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
 | 
				
			||||||
 | 
					      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
 | 
				
			||||||
 | 
					          desc.message_type.has_options and
 | 
				
			||||||
 | 
					          desc.message_type.GetOptions().map_entry):
 | 
				
			||||||
 | 
					        # This is a map, only recurse if the values have a message type.
 | 
				
			||||||
 | 
					        if (desc.message_type.fields_by_number[2].type ==
 | 
				
			||||||
 | 
					            descriptor.FieldDescriptor.TYPE_MESSAGE):
 | 
				
			||||||
 | 
					          for v in six.itervalues(values):
 | 
				
			||||||
 | 
					            _normalize_number_fields(v)
 | 
				
			||||||
 | 
					      else:
 | 
				
			||||||
 | 
					        for v in values:
 | 
				
			||||||
 | 
					          # recursive step
 | 
				
			||||||
 | 
					          _normalize_number_fields(v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return pb
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -50,11 +50,6 @@ _SCORE_THRESHOLD = 0.5
 | 
				
			||||||
_MAX_RESULTS = 3
 | 
					_MAX_RESULTS = 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: Port assertProtoEquals
 | 
					 | 
				
			||||||
def _assert_proto_equals(expected, actual):  # pylint: disable=unused-argument
 | 
					 | 
				
			||||||
  pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
 | 
					def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
 | 
				
			||||||
  return _ClassificationResult(classifications=[
 | 
					  return _ClassificationResult(classifications=[
 | 
				
			||||||
      _Classifications(
 | 
					      _Classifications(
 | 
				
			||||||
| 
						 | 
					@ -74,22 +69,22 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
 | 
				
			||||||
                  categories=[
 | 
					                  categories=[
 | 
				
			||||||
                      _Category(
 | 
					                      _Category(
 | 
				
			||||||
                          index=934,
 | 
					                          index=934,
 | 
				
			||||||
                          score=0.7939587831497192,
 | 
					                          score=0.793959,
 | 
				
			||||||
                          display_name='',
 | 
					                          display_name='',
 | 
				
			||||||
                          category_name='cheeseburger'),
 | 
					                          category_name='cheeseburger'),
 | 
				
			||||||
                      _Category(
 | 
					                      _Category(
 | 
				
			||||||
                          index=932,
 | 
					                          index=932,
 | 
				
			||||||
                          score=0.02739289402961731,
 | 
					                          score=0.0273929,
 | 
				
			||||||
                          display_name='',
 | 
					                          display_name='',
 | 
				
			||||||
                          category_name='bagel'),
 | 
					                          category_name='bagel'),
 | 
				
			||||||
                      _Category(
 | 
					                      _Category(
 | 
				
			||||||
                          index=925,
 | 
					                          index=925,
 | 
				
			||||||
                          score=0.01934075355529785,
 | 
					                          score=0.0193408,
 | 
				
			||||||
                          display_name='',
 | 
					                          display_name='',
 | 
				
			||||||
                          category_name='guacamole'),
 | 
					                          category_name='guacamole'),
 | 
				
			||||||
                      _Category(
 | 
					                      _Category(
 | 
				
			||||||
                          index=963,
 | 
					                          index=963,
 | 
				
			||||||
                          score=0.006327860057353973,
 | 
					                          score=0.00632786,
 | 
				
			||||||
                          display_name='',
 | 
					                          display_name='',
 | 
				
			||||||
                          category_name='meat loaf')
 | 
					                          category_name='meat loaf')
 | 
				
			||||||
                  ],
 | 
					                  ],
 | 
				
			||||||
| 
						 | 
					@ -108,7 +103,7 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
 | 
				
			||||||
                  categories=[
 | 
					                  categories=[
 | 
				
			||||||
                      _Category(
 | 
					                      _Category(
 | 
				
			||||||
                          index=806,
 | 
					                          index=806,
 | 
				
			||||||
                          score=0.9965274930000305,
 | 
					                          score=0.996527,
 | 
				
			||||||
                          display_name='',
 | 
					                          display_name='',
 | 
				
			||||||
                          category_name='soccer ball')
 | 
					                          category_name='soccer ball')
 | 
				
			||||||
                  ],
 | 
					                  ],
 | 
				
			||||||
| 
						 | 
					@ -186,7 +181,7 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
    # Performs image classification on the input.
 | 
					    # Performs image classification on the input.
 | 
				
			||||||
    image_result = classifier.classify(self.test_image)
 | 
					    image_result = classifier.classify(self.test_image)
 | 
				
			||||||
    # Comparing results.
 | 
					    # Comparing results.
 | 
				
			||||||
    _assert_proto_equals(image_result.to_pb2(),
 | 
					    test_utils.assert_proto_equals(self, image_result.to_pb2(),
 | 
				
			||||||
                                   expected_classification_result.to_pb2())
 | 
					                                   expected_classification_result.to_pb2())
 | 
				
			||||||
    # Closes the classifier explicitly when the classifier is not used in
 | 
					    # Closes the classifier explicitly when the classifier is not used in
 | 
				
			||||||
    # a context.
 | 
					    # a context.
 | 
				
			||||||
| 
						 | 
					@ -214,7 +209,7 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
      # Performs image classification on the input.
 | 
					      # Performs image classification on the input.
 | 
				
			||||||
      image_result = classifier.classify(self.test_image)
 | 
					      image_result = classifier.classify(self.test_image)
 | 
				
			||||||
      # Comparing results.
 | 
					      # Comparing results.
 | 
				
			||||||
      _assert_proto_equals(image_result.to_pb2(),
 | 
					      test_utils.assert_proto_equals(self, image_result.to_pb2(),
 | 
				
			||||||
                                     expected_classification_result.to_pb2())
 | 
					                                     expected_classification_result.to_pb2())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_classify_succeeds_with_region_of_interest(self):
 | 
					  def test_classify_succeeds_with_region_of_interest(self):
 | 
				
			||||||
| 
						 | 
					@ -232,7 +227,7 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
      # Performs image classification on the input.
 | 
					      # Performs image classification on the input.
 | 
				
			||||||
      image_result = classifier.classify(test_image, roi)
 | 
					      image_result = classifier.classify(test_image, roi)
 | 
				
			||||||
      # Comparing results.
 | 
					      # Comparing results.
 | 
				
			||||||
      _assert_proto_equals(image_result.to_pb2(),
 | 
					      test_utils.assert_proto_equals(self, image_result.to_pb2(),
 | 
				
			||||||
                                     _generate_soccer_ball_results(0).to_pb2())
 | 
					                                     _generate_soccer_ball_results(0).to_pb2())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_score_threshold_option(self):
 | 
					  def test_score_threshold_option(self):
 | 
				
			||||||
| 
						 | 
					@ -401,7 +396,8 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
      for timestamp in range(0, 300, 30):
 | 
					      for timestamp in range(0, 300, 30):
 | 
				
			||||||
        classification_result = classifier.classify_for_video(
 | 
					        classification_result = classifier.classify_for_video(
 | 
				
			||||||
            self.test_image, timestamp)
 | 
					            self.test_image, timestamp)
 | 
				
			||||||
        _assert_proto_equals(classification_result.to_pb2(),
 | 
					        test_utils.assert_proto_equals(
 | 
				
			||||||
 | 
					            self, classification_result.to_pb2(),
 | 
				
			||||||
            _generate_burger_results(timestamp).to_pb2())
 | 
					            _generate_burger_results(timestamp).to_pb2())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_classify_for_video_succeeds_with_region_of_interest(self):
 | 
					  def test_classify_for_video_succeeds_with_region_of_interest(self):
 | 
				
			||||||
| 
						 | 
					@ -420,8 +416,9 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
      for timestamp in range(0, 300, 30):
 | 
					      for timestamp in range(0, 300, 30):
 | 
				
			||||||
        classification_result = classifier.classify_for_video(
 | 
					        classification_result = classifier.classify_for_video(
 | 
				
			||||||
            test_image, timestamp, roi)
 | 
					            test_image, timestamp, roi)
 | 
				
			||||||
        self.assertEqual(classification_result,
 | 
					        test_utils.assert_proto_equals(
 | 
				
			||||||
                         _generate_soccer_ball_results(timestamp))
 | 
					            self, classification_result.to_pb2(),
 | 
				
			||||||
 | 
					            _generate_soccer_ball_results(timestamp).to_pb2())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_calling_classify_in_live_stream_mode(self):
 | 
					  def test_calling_classify_in_live_stream_mode(self):
 | 
				
			||||||
    options = _ImageClassifierOptions(
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
| 
						 | 
					@ -463,7 +460,7 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def check_result(result: _ClassificationResult, output_image: _Image,
 | 
					    def check_result(result: _ClassificationResult, output_image: _Image,
 | 
				
			||||||
                     timestamp_ms: int):
 | 
					                     timestamp_ms: int):
 | 
				
			||||||
      _assert_proto_equals(result.to_pb2(),
 | 
					      test_utils.assert_proto_equals(self, result.to_pb2(),
 | 
				
			||||||
                                     expected_result_fn(timestamp_ms).to_pb2())
 | 
					                                     expected_result_fn(timestamp_ms).to_pb2())
 | 
				
			||||||
      self.assertTrue(
 | 
					      self.assertTrue(
 | 
				
			||||||
          np.array_equal(output_image.numpy_view(),
 | 
					          np.array_equal(output_image.numpy_view(),
 | 
				
			||||||
| 
						 | 
					@ -493,7 +490,8 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def check_result(result: _ClassificationResult, output_image: _Image,
 | 
					    def check_result(result: _ClassificationResult, output_image: _Image,
 | 
				
			||||||
                     timestamp_ms: int):
 | 
					                     timestamp_ms: int):
 | 
				
			||||||
      _assert_proto_equals(result.to_pb2(),
 | 
					      test_utils.assert_proto_equals(
 | 
				
			||||||
 | 
					          self, result.to_pb2(),
 | 
				
			||||||
          _generate_soccer_ball_results(timestamp_ms).to_pb2())
 | 
					          _generate_soccer_ball_results(timestamp_ms).to_pb2())
 | 
				
			||||||
      self.assertEqual(output_image.width, test_image.width)
 | 
					      self.assertEqual(output_image.width, test_image.width)
 | 
				
			||||||
      self.assertEqual(output_image.height, test_image.height)
 | 
					      self.assertEqual(output_image.height, test_image.height)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user