Improve model_util_test code.

PiperOrigin-RevId: 488752497
This commit is contained in:
MediaPipe Team 2022-11-15 14:02:21 -08:00 committed by Copybara-Service
parent a94564540b
commit 1689112b23

View File

@ -13,6 +13,7 @@
# limitations under the License.
import os
from typing import Optional
from absl.testing import parameterized
import tensorflow as tf
@ -76,8 +77,10 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]]),
expected_steps_per_epoch=2))
def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
expected_steps_per_epoch):
def test_get_steps_per_epoch(self, steps_per_epoch: Optional[int],
batch_size: Optional[int],
train_data: Optional[tf.data.Dataset],
expected_steps_per_epoch: int):
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=steps_per_epoch,
batch_size=batch_size,
@ -130,7 +133,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
testcase_name='float16_quantize',
config=quantization.QuantizationConfig.for_float16(),
model_size=1468))
def test_convert_to_tflite_quantized(self, config, model_size):
def test_convert_to_tflite_quantized(self,
config: quantization.QuantizationConfig,
model_size: int):
input_dim = 16
num_classes = 2
max_input_value = 5
@ -157,5 +162,6 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
test_util.test_tflite_file(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
if __name__ == '__main__':
tf.test.main()