diff --git a/mediapipe/model_maker/python/core/tasks/custom_model.py b/mediapipe/model_maker/python/core/tasks/custom_model.py index 2cea4e0a1..66d1494db 100644 --- a/mediapipe/model_maker/python/core/tasks/custom_model.py +++ b/mediapipe/model_maker/python/core/tasks/custom_model.py @@ -21,8 +21,6 @@ import abc import os from typing import Any, Callable, Optional -# Dependency imports - import tensorflow as tf from mediapipe.model_maker.python.core.data import dataset @@ -77,9 +75,9 @@ class CustomModel(abc.ABC): tflite_filepath = os.path.join(export_dir, tflite_filename) # TODO: Populate metadata to the exported TFLite model. model_util.export_tflite( - self._model, - tflite_filepath, - quantization_config, + model=self._model, + tflite_filepath=tflite_filepath, + quantization_config=quantization_config, preprocess=preprocess) tf.compat.v1.logging.info( 'TensorFlow Lite model exported successfully: %s' % tflite_filepath) diff --git a/mediapipe/model_maker/python/core/tasks/custom_model_test.py b/mediapipe/model_maker/python/core/tasks/custom_model_test.py index e693e1275..ad77d4ecd 100644 --- a/mediapipe/model_maker/python/core/tasks/custom_model_test.py +++ b/mediapipe/model_maker/python/core/tasks/custom_model_test.py @@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase): def setUp(self): super(CustomModelTest, self).setUp() - self.model = MockCustomModel(model_spec=None, shuffle=False) - self.model._model = test_util.build_model(input_shape=[4], num_classes=2) + self._model = MockCustomModel(model_spec=None, shuffle=False) + self._model._model = test_util.build_model(input_shape=[4], num_classes=2) def _check_nonempty_file(self, filepath): self.assertTrue(os.path.isfile(filepath)) @@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase): def test_export_tflite(self): export_path = os.path.join(self.get_temp_dir(), 'export/') - self.model.export_tflite(export_dir=export_path) + self._model.export_tflite(export_dir=export_path) self._check_nonempty_file(os.path.join(export_path, 'model.tflite')) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 0899a9b1a..e1228eb6e 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -104,8 +104,8 @@ def export_tflite( quantization_config: Configuration for post-training quantization. supported_ops: A list of supported ops in the converted TFLite file. preprocess: A callable to preprocess the representative dataset for - quantization. The callable takes three arguments in order: feature, - label, and is_training. + quantization. The callable takes three arguments in order: feature, label, + and is_training. """ if tflite_filepath is None: raise ValueError( diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index ce31c1877..35b52eb75 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): model = test_util.build_model(input_shape=[input_dim], num_classes=2) tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') model_util.export_tflite(model, tflite_file) - self._test_tflite(model, tflite_file, input_dim) + test_util.test_tflite( + keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) @parameterized.named_parameters( dict( @@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): input_dim = 16 num_classes = 2 max_input_value = 5 - model = test_util.build_model([input_dim], num_classes) + model = test_util.build_model( + input_shape=[input_dim], num_classes=num_classes) tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite') - model_util.export_tflite(model, tflite_file, config) - self._test_tflite( - model, tflite_file, input_dim, max_input_value, atol=1e-00) - self.assertNear(os.path.getsize(tflite_file), model_size, 300) - - def _test_tflite(self, - keras_model: tf.keras.Model, - tflite_model_file: str, - input_dim: int, - max_input_value: int = 1000, - atol: float = 1e-04): - random_input = test_util.create_random_sample( - size=[1, input_dim], high=max_input_value) - random_input = tf.convert_to_tensor(random_input) - + model_util.export_tflite( + model=model, tflite_filepath=tflite_file, quantization_config=config) self.assertTrue( - test_util.is_same_output( - tflite_model_file, keras_model, random_input, atol=atol)) + test_util.test_tflite( + keras_model=model, + tflite_file=tflite_file, + size=[1, input_dim], + high=max_input_value, + atol=1e-00)) + self.assertNear(os.path.getsize(tflite_file), model_size, 300) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/core/utils/test_util.py b/mediapipe/model_maker/python/core/utils/test_util.py index cac2a0e1f..b402d3793 100644 --- a/mediapipe/model_maker/python/core/utils/test_util.py +++ b/mediapipe/model_maker/python/core/utils/test_util.py @@ -92,3 +92,32 @@ def is_same_output(tflite_file: str, keras_output = keras_model.predict_on_batch(input_tensors) return np.allclose(lite_output, keras_output, atol=atol) + + +def test_tflite(keras_model: tf.keras.Model, + tflite_file: str, + size: Union[int, List[int]], + high: float = 1, + atol: float = 1e-04) -> bool: + """Verifies if the output of TFLite model and TF Keras model are identical. + + Args: + keras_model: Input TensorFlow Keras model. + tflite_file: Input TFLite model file. + size: Size of the input tesnor. + high: Higher boundary of the values in input tensors. + atol: Absolute tolerance of the difference between the outputs of Keras + model and TFLite model. + + Returns: + True if the output of TFLite model and TF Keras model are identical. + Otherwise, False. + """ + random_input = create_random_sample(size=size, high=high) + random_input = tf.convert_to_tensor(random_input) + + return is_same_output( + tflite_file=tflite_file, + keras_model=keras_model, + input_tensors=random_input, + atol=atol)