Add an export_tflite API to gesture recognizer model maker library.
PiperOrigin-RevId: 482527017
parent 467cd34feb
commit e71638cf67
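For orientation, a minimal usage sketch of the keyword-argument export path this change introduces. It mirrors the calls in the hunks below; the module import paths and the output location are assumptions for illustration and are not part of this commit.

import os

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import test_util

# Build a tiny Keras model with the same test utility used in the tests below.
model = test_util.build_model(input_shape=[4], num_classes=2)

# Export with the keyword-argument form now used by CustomModel.export_tflite.
tflite_filepath = os.path.join('/tmp', 'model.tflite')
model_util.export_tflite(
    model=model,
    tflite_filepath=tflite_filepath,
    quantization_config=None)
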
@@ -21,8 +21,6 @@ import abc
 import os
 from typing import Any, Callable, Optional
 
-# Dependency imports
-
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import dataset

@@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
     tflite_filepath = os.path.join(export_dir, tflite_filename)
     # TODO: Populate metadata to the exported TFLite model.
     model_util.export_tflite(
-        self._model,
-        tflite_filepath,
-        quantization_config,
+        model=self._model,
+        tflite_filepath=tflite_filepath,
+        quantization_config=quantization_config,
         preprocess=preprocess)
     tf.compat.v1.logging.info(
         'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

@@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):
 
   def setUp(self):
     super(CustomModelTest, self).setUp()
-    self.model = MockCustomModel(model_spec=None, shuffle=False)
-    self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
+    self._model = MockCustomModel(model_spec=None, shuffle=False)
+    self._model._model = test_util.build_model(input_shape=[4], num_classes=2)
 
   def _check_nonempty_file(self, filepath):
     self.assertTrue(os.path.isfile(filepath))

@@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):
 
   def test_export_tflite(self):
     export_path = os.path.join(self.get_temp_dir(), 'export/')
-    self.model.export_tflite(export_dir=export_path)
+    self._model.export_tflite(export_dir=export_path)
     self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
 
 if __name__ == '__main__':

@@ -104,8 +104,8 @@ def export_tflite(
     quantization_config: Configuration for post-training quantization.
     supported_ops: A list of supported ops in the converted TFLite file.
     preprocess: A callable to preprocess the representative dataset for
-      quantization. The callable takes three arguments in order: feature,
-      label, and is_training.
+      quantization. The callable takes three arguments in order: feature, label,
+      and is_training.
   """
   if tflite_filepath is None:
     raise ValueError(

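As a sketch of the callable described by the preprocess argument above: it receives (feature, label, is_training) and, presumably, returns the transformed (feature, label) pair used to build the representative dataset for quantization. The body below is purely illustrative; the scaling constant is an assumption, not something defined in this commit.

import tensorflow as tf

def preprocess(feature: tf.Tensor, label: tf.Tensor, is_training: bool):
  # Illustrative only: cast and scale the feature, leave the label unchanged.
  feature = tf.cast(feature, tf.float32) / 255.0  # assumed [0, 255] input range
  return feature, label

# Passed through as a keyword argument, e.g.:
# model_util.export_tflite(
#     model=model, tflite_filepath=tflite_filepath,
#     quantization_config=quantization_config, preprocess=preprocess)
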
@@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
     model = test_util.build_model(input_shape=[input_dim], num_classes=2)
     tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
     model_util.export_tflite(model, tflite_file)
-    self._test_tflite(model, tflite_file, input_dim)
+    test_util.test_tflite(
+        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
 
   @parameterized.named_parameters(
       dict(

@@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
     input_dim = 16
     num_classes = 2
     max_input_value = 5
-    model = test_util.build_model([input_dim], num_classes)
+    model = test_util.build_model(
+        input_shape=[input_dim], num_classes=num_classes)
     tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
 
-    model_util.export_tflite(model, tflite_file, config)
-    self._test_tflite(
-        model, tflite_file, input_dim, max_input_value, atol=1e-00)
-    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
-
-  def _test_tflite(self,
-                   keras_model: tf.keras.Model,
-                   tflite_model_file: str,
-                   input_dim: int,
-                   max_input_value: int = 1000,
-                   atol: float = 1e-04):
-    random_input = test_util.create_random_sample(
-        size=[1, input_dim], high=max_input_value)
-    random_input = tf.convert_to_tensor(random_input)
-
-    self.assertTrue(
-        test_util.is_same_output(
-            tflite_model_file, keras_model, random_input, atol=atol))
+    model_util.export_tflite(
+        model=model, tflite_filepath=tflite_file, quantization_config=config)
+    self.assertTrue(
+        test_util.test_tflite(
+            keras_model=model,
+            tflite_file=tflite_file,
+            size=[1, input_dim],
+            high=max_input_value,
+            atol=1e-00))
+    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
 
 
 if __name__ == '__main__':

@@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
   keras_output = keras_model.predict_on_batch(input_tensors)
 
   return np.allclose(lite_output, keras_output, atol=atol)
+
+
+def test_tflite(keras_model: tf.keras.Model,
+                tflite_file: str,
+                size: Union[int, List[int]],
+                high: float = 1,
+                atol: float = 1e-04) -> bool:
+  """Verifies if the output of TFLite model and TF Keras model are identical.
+
+  Args:
+    keras_model: Input TensorFlow Keras model.
+    tflite_file: Input TFLite model file.
+    size: Size of the input tensor.
+    high: Higher boundary of the values in input tensors.
+    atol: Absolute tolerance of the difference between the outputs of Keras
+      model and TFLite model.
+
+  Returns:
+    True if the output of TFLite model and TF Keras model are identical.
+      Otherwise, False.
+  """
+  random_input = create_random_sample(size=size, high=high)
+  random_input = tf.convert_to_tensor(random_input)
+
+  return is_same_output(
+      tflite_file=tflite_file,
+      keras_model=keras_model,
+      input_tensors=random_input,
+      atol=atol)

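A short usage sketch of the new test_util.test_tflite helper, mirroring how the model_util_test hunks above call it; the import paths and the /tmp path are assumptions for illustration, not taken from this commit.

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import test_util

model = test_util.build_model(input_shape=[16], num_classes=2)
tflite_file = '/tmp/model.tflite'
model_util.export_tflite(model, tflite_file)

# Returns True when the Keras and TFLite outputs agree within atol on a random
# input tensor of the given size.
assert test_util.test_tflite(
    keras_model=model, tflite_file=tflite_file, size=[1, 16])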