Add an export_tflite API to gesture recognizer model maker library.
PiperOrigin-RevId: 482527017
This commit is contained in:
parent
467cd34feb
commit
e71638cf67
|
@ -21,8 +21,6 @@ import abc
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
# Dependency imports
|
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.data import dataset
|
from mediapipe.model_maker.python.core.data import dataset
|
||||||
|
@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
|
||||||
tflite_filepath = os.path.join(export_dir, tflite_filename)
|
tflite_filepath = os.path.join(export_dir, tflite_filename)
|
||||||
# TODO: Populate metadata to the exported TFLite model.
|
# TODO: Populate metadata to the exported TFLite model.
|
||||||
model_util.export_tflite(
|
model_util.export_tflite(
|
||||||
self._model,
|
model=self._model,
|
||||||
tflite_filepath,
|
tflite_filepath=tflite_filepath,
|
||||||
quantization_config,
|
quantization_config=quantization_config,
|
||||||
preprocess=preprocess)
|
preprocess=preprocess)
|
||||||
tf.compat.v1.logging.info(
|
tf.compat.v1.logging.info(
|
||||||
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
|
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
|
||||||
|
|
|
@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(CustomModelTest, self).setUp()
|
super(CustomModelTest, self).setUp()
|
||||||
self.model = MockCustomModel(model_spec=None, shuffle=False)
|
self._model = MockCustomModel(model_spec=None, shuffle=False)
|
||||||
self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
|
self._model._model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||||
|
|
||||||
def _check_nonempty_file(self, filepath):
|
def _check_nonempty_file(self, filepath):
|
||||||
self.assertTrue(os.path.isfile(filepath))
|
self.assertTrue(os.path.isfile(filepath))
|
||||||
|
@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_export_tflite(self):
|
def test_export_tflite(self):
|
||||||
export_path = os.path.join(self.get_temp_dir(), 'export/')
|
export_path = os.path.join(self.get_temp_dir(), 'export/')
|
||||||
self.model.export_tflite(export_dir=export_path)
|
self._model.export_tflite(export_dir=export_path)
|
||||||
self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
|
self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -104,8 +104,8 @@ def export_tflite(
|
||||||
quantization_config: Configuration for post-training quantization.
|
quantization_config: Configuration for post-training quantization.
|
||||||
supported_ops: A list of supported ops in the converted TFLite file.
|
supported_ops: A list of supported ops in the converted TFLite file.
|
||||||
preprocess: A callable to preprocess the representative dataset for
|
preprocess: A callable to preprocess the representative dataset for
|
||||||
quantization. The callable takes three arguments in order: feature,
|
quantization. The callable takes three arguments in order: feature, label,
|
||||||
label, and is_training.
|
and is_training.
|
||||||
"""
|
"""
|
||||||
if tflite_filepath is None:
|
if tflite_filepath is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||||
model_util.export_tflite(model, tflite_file)
|
model_util.export_tflite(model, tflite_file)
|
||||||
self._test_tflite(model, tflite_file, input_dim)
|
test_util.test_tflite(
|
||||||
|
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
dict(
|
dict(
|
||||||
|
@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
input_dim = 16
|
input_dim = 16
|
||||||
num_classes = 2
|
num_classes = 2
|
||||||
max_input_value = 5
|
max_input_value = 5
|
||||||
model = test_util.build_model([input_dim], num_classes)
|
model = test_util.build_model(
|
||||||
|
input_shape=[input_dim], num_classes=num_classes)
|
||||||
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
|
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
|
||||||
|
|
||||||
model_util.export_tflite(model, tflite_file, config)
|
model_util.export_tflite(
|
||||||
self._test_tflite(
|
model=model, tflite_filepath=tflite_file, quantization_config=config)
|
||||||
model, tflite_file, input_dim, max_input_value, atol=1e-00)
|
|
||||||
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
|
||||||
|
|
||||||
def _test_tflite(self,
|
|
||||||
keras_model: tf.keras.Model,
|
|
||||||
tflite_model_file: str,
|
|
||||||
input_dim: int,
|
|
||||||
max_input_value: int = 1000,
|
|
||||||
atol: float = 1e-04):
|
|
||||||
random_input = test_util.create_random_sample(
|
|
||||||
size=[1, input_dim], high=max_input_value)
|
|
||||||
random_input = tf.convert_to_tensor(random_input)
|
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
test_util.is_same_output(
|
test_util.test_tflite(
|
||||||
tflite_model_file, keras_model, random_input, atol=atol))
|
keras_model=model,
|
||||||
|
tflite_file=tflite_file,
|
||||||
|
size=[1, input_dim],
|
||||||
|
high=max_input_value,
|
||||||
|
atol=1e-00))
|
||||||
|
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
|
||||||
keras_output = keras_model.predict_on_batch(input_tensors)
|
keras_output = keras_model.predict_on_batch(input_tensors)
|
||||||
|
|
||||||
return np.allclose(lite_output, keras_output, atol=atol)
|
return np.allclose(lite_output, keras_output, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tflite(keras_model: tf.keras.Model,
|
||||||
|
tflite_file: str,
|
||||||
|
size: Union[int, List[int]],
|
||||||
|
high: float = 1,
|
||||||
|
atol: float = 1e-04) -> bool:
|
||||||
|
"""Verifies if the output of TFLite model and TF Keras model are identical.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keras_model: Input TensorFlow Keras model.
|
||||||
|
tflite_file: Input TFLite model file.
|
||||||
|
size: Size of the input tesnor.
|
||||||
|
high: Higher boundary of the values in input tensors.
|
||||||
|
atol: Absolute tolerance of the difference between the outputs of Keras
|
||||||
|
model and TFLite model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the output of TFLite model and TF Keras model are identical.
|
||||||
|
Otherwise, False.
|
||||||
|
"""
|
||||||
|
random_input = create_random_sample(size=size, high=high)
|
||||||
|
random_input = tf.convert_to_tensor(random_input)
|
||||||
|
|
||||||
|
return is_same_output(
|
||||||
|
tflite_file=tflite_file,
|
||||||
|
keras_model=keras_model,
|
||||||
|
input_tensors=random_input,
|
||||||
|
atol=atol)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user