Improve model_util_test code.
PiperOrigin-RevId: 488752497
This commit is contained in:
parent
a94564540b
commit
1689112b23
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -76,8 +77,10 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||||
[1, 0]]),
|
[1, 0]]),
|
||||||
expected_steps_per_epoch=2))
|
expected_steps_per_epoch=2))
|
||||||
def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
|
def test_get_steps_per_epoch(self, steps_per_epoch: Optional[int],
|
||||||
expected_steps_per_epoch):
|
batch_size: Optional[int],
|
||||||
|
train_data: Optional[tf.data.Dataset],
|
||||||
|
expected_steps_per_epoch: int):
|
||||||
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
|
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
|
||||||
steps_per_epoch=steps_per_epoch,
|
steps_per_epoch=steps_per_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
@ -130,7 +133,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
testcase_name='float16_quantize',
|
testcase_name='float16_quantize',
|
||||||
config=quantization.QuantizationConfig.for_float16(),
|
config=quantization.QuantizationConfig.for_float16(),
|
||||||
model_size=1468))
|
model_size=1468))
|
||||||
def test_convert_to_tflite_quantized(self, config, model_size):
|
def test_convert_to_tflite_quantized(self,
|
||||||
|
config: quantization.QuantizationConfig,
|
||||||
|
model_size: int):
|
||||||
input_dim = 16
|
input_dim = 16
|
||||||
num_classes = 2
|
num_classes = 2
|
||||||
max_input_value = 5
|
max_input_value = 5
|
||||||
|
@ -157,5 +162,6 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
test_util.test_tflite_file(
|
test_util.test_tflite_file(
|
||||||
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user