From 1689112b23fc6038114a143baf0253e0b6c043c6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 14:02:21 -0800 Subject: [PATCH] Improve model_util_test code. PiperOrigin-RevId: 488752497 --- .../model_maker/python/core/utils/model_util_test.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index bef9c8a97..05c6ffe3f 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from typing import Optional from absl.testing import parameterized import tensorflow as tf @@ -76,8 +77,10 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]), expected_steps_per_epoch=2)) - def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data, - expected_steps_per_epoch): + def test_get_steps_per_epoch(self, steps_per_epoch: Optional[int], + batch_size: Optional[int], + train_data: Optional[tf.data.Dataset], + expected_steps_per_epoch: int): estimated_steps_per_epoch = model_util.get_steps_per_epoch( steps_per_epoch=steps_per_epoch, batch_size=batch_size, @@ -130,7 +133,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): testcase_name='float16_quantize', config=quantization.QuantizationConfig.for_float16(), model_size=1468)) - def test_convert_to_tflite_quantized(self, config, model_size): + def test_convert_to_tflite_quantized(self, + config: quantization.QuantizationConfig, + model_size: int): input_dim = 16 num_classes = 2 max_input_value = 5 @@ -157,5 +162,6 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): test_util.test_tflite_file( keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) + if __name__ == '__main__': tf.test.main()