Improve model_util_test code.
PiperOrigin-RevId: 488752497
This commit is contained in:
parent
a94564540b
commit
1689112b23
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
@ -76,8 +77,10 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]]),
|
||||
expected_steps_per_epoch=2))
|
||||
def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
|
||||
expected_steps_per_epoch):
|
||||
def test_get_steps_per_epoch(self, steps_per_epoch: Optional[int],
|
||||
batch_size: Optional[int],
|
||||
train_data: Optional[tf.data.Dataset],
|
||||
expected_steps_per_epoch: int):
|
||||
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
batch_size=batch_size,
|
||||
|
@ -130,7 +133,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
testcase_name='float16_quantize',
|
||||
config=quantization.QuantizationConfig.for_float16(),
|
||||
model_size=1468))
|
||||
def test_convert_to_tflite_quantized(self, config, model_size):
|
||||
def test_convert_to_tflite_quantized(self,
|
||||
config: quantization.QuantizationConfig,
|
||||
model_size: int):
|
||||
input_dim = 16
|
||||
num_classes = 2
|
||||
max_input_value = 5
|
||||
|
@ -157,5 +162,6 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
test_util.test_tflite_file(
|
||||
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in New Issue
Block a user