Add an export_tflite API to gesture recognizer model maker library.

PiperOrigin-RevId: 482527017
MediaPipe Team 2022-10-20 10:19:15 -07:00 committed by Copybara-Service
parent 467cd34feb
commit e71638cf67
5 changed files with 50 additions and 29 deletions

@@ -21,8 +21,6 @@ import abc
 import os
 from typing import Any, Callable, Optional

-# Dependency imports
-
 import tensorflow as tf

 from mediapipe.model_maker.python.core.data import dataset
@@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
     tflite_filepath = os.path.join(export_dir, tflite_filename)
     # TODO: Populate metadata to the exported TFLite model.
     model_util.export_tflite(
-        self._model,
-        tflite_filepath,
-        quantization_config,
+        model=self._model,
+        tflite_filepath=tflite_filepath,
+        quantization_config=quantization_config,
         preprocess=preprocess)
     tf.compat.v1.logging.info(
         'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

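For reference, a minimal sketch (not part of the commit) of calling model_util.export_tflite directly with the keyword arguments introduced above. The module path mediapipe.model_maker.python.core.utils.model_util, the toy Keras model, and the /tmp output directory are assumptions for illustration; quantization_config=None is assumed to skip post-training quantization.

import os

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util  # assumed module path

# Toy Keras model standing in for a trained model-maker model (illustrative only).
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=[4]),
    tf.keras.layers.Dense(2, activation='softmax'),
])

export_dir = '/tmp/exported_model'  # hypothetical output directory
os.makedirs(export_dir, exist_ok=True)

# Keyword-argument call matching the updated CustomModel.export_tflite above;
# quantization_config=None is assumed to produce an unquantized float model.
model_util.export_tflite(
    model=model,
    tflite_filepath=os.path.join(export_dir, 'model.tflite'),
    quantization_config=None)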
@@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):

   def setUp(self):
     super(CustomModelTest, self).setUp()
-    self.model = MockCustomModel(model_spec=None, shuffle=False)
-    self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
+    self._model = MockCustomModel(model_spec=None, shuffle=False)
+    self._model._model = test_util.build_model(input_shape=[4], num_classes=2)

   def _check_nonempty_file(self, filepath):
     self.assertTrue(os.path.isfile(filepath))
@@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):
   def test_export_tflite(self):
     export_path = os.path.join(self.get_temp_dir(), 'export/')
-    self.model.export_tflite(export_dir=export_path)
+    self._model.export_tflite(export_dir=export_path)
     self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))


 if __name__ == '__main__':

@@ -104,8 +104,8 @@ def export_tflite(
     quantization_config: Configuration for post-training quantization.
     supported_ops: A list of supported ops in the converted TFLite file.
     preprocess: A callable to preprocess the representative dataset for
-      quantization. The callable takes three arguments in order: feature,
-      label, and is_training.
+      quantization. The callable takes three arguments in order: feature, label,
+      and is_training.
   """
   if tflite_filepath is None:
     raise ValueError(

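The preprocess hook documented above is described only by its argument order (feature, label, is_training). A hedged sketch of what such a callable could look like, assuming it returns the transformed (feature, label) pair, which this hunk does not show:

import tensorflow as tf

def preprocess(feature: tf.Tensor, label: tf.Tensor, is_training: bool):
  """Illustrative preprocess callable for representative-dataset quantization.

  Assumes the callable returns the transformed (feature, label) pair; the
  exact return contract is not shown in this hunk.
  """
  # Scale image-like features to [0, 1]; apply training-only augmentation
  # when is_training is True (purely illustrative choices).
  feature = tf.cast(feature, tf.float32) / 255.0
  if is_training:
    feature = tf.image.random_flip_left_right(feature)
  return feature, label

# Such a callable would be passed as preprocess=preprocess through
# CustomModel.export_tflite to model_util.export_tflite, as wired above.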
@@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
     model = test_util.build_model(input_shape=[input_dim], num_classes=2)
     tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
     model_util.export_tflite(model, tflite_file)
-    self._test_tflite(model, tflite_file, input_dim)
+    test_util.test_tflite(
+        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])

   @parameterized.named_parameters(
       dict(
@@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
     input_dim = 16
     num_classes = 2
     max_input_value = 5
-    model = test_util.build_model([input_dim], num_classes)
+    model = test_util.build_model(
+        input_shape=[input_dim], num_classes=num_classes)
     tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
-    model_util.export_tflite(model, tflite_file, config)
-    self._test_tflite(
-        model, tflite_file, input_dim, max_input_value, atol=1e-00)
-    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
-
-  def _test_tflite(self,
-                   keras_model: tf.keras.Model,
-                   tflite_model_file: str,
-                   input_dim: int,
-                   max_input_value: int = 1000,
-                   atol: float = 1e-04):
-    random_input = test_util.create_random_sample(
-        size=[1, input_dim], high=max_input_value)
-    random_input = tf.convert_to_tensor(random_input)
     self.assertTrue(
-        test_util.is_same_output(
-            tflite_model_file, keras_model, random_input, atol=atol))
+        test_util.test_tflite(
+            keras_model=model,
+            tflite_file=tflite_file,
+            size=[1, input_dim],
+            high=max_input_value,
+            atol=1e-00))
+    self.assertNear(os.path.getsize(tflite_file), model_size, 300)


 if __name__ == '__main__':

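As a usage note, a sketch of the quantized-export flow that the test above exercises. Only the config argument itself appears in the hunk; the QuantizationConfig.for_dynamic() helper and the quantization module path are assumptions, and the model and output path are placeholders.

import os

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util  # assumed module path
from mediapipe.model_maker.python.core.utils import quantization  # assumed module path

# Placeholder model; a real workflow would use a trained model-maker model.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=[16]),
    tf.keras.layers.Dense(2, activation='softmax'),
])

# Assumed helper for dynamic-range quantization; not shown in this commit.
config = quantization.QuantizationConfig.for_dynamic()

tflite_file = os.path.join('/tmp', 'model_quantized.tflite')  # placeholder path
model_util.export_tflite(
    model=model, tflite_filepath=tflite_file, quantization_config=config)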
@@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
   keras_output = keras_model.predict_on_batch(input_tensors)

   return np.allclose(lite_output, keras_output, atol=atol)
+
+
+def test_tflite(keras_model: tf.keras.Model,
+                tflite_file: str,
+                size: Union[int, List[int]],
+                high: float = 1,
+                atol: float = 1e-04) -> bool:
+  """Verifies whether the outputs of a TFLite model and a TF Keras model match.
+
+  Args:
+    keras_model: Input TensorFlow Keras model.
+    tflite_file: Input TFLite model file.
+    size: Size of the input tensor.
+    high: Higher boundary of the values in input tensors.
+    atol: Absolute tolerance of the difference between the outputs of the Keras
+      model and the TFLite model.
+
+  Returns:
+    True if the outputs of the TFLite model and the TF Keras model are
+      identical. Otherwise, False.
+  """
+  random_input = create_random_sample(size=size, high=high)
+  random_input = tf.convert_to_tensor(random_input)
+
+  return is_same_output(
+      tflite_file=tflite_file,
+      keras_model=keras_model,
+      input_tensors=random_input,
+      atol=atol)
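A short usage sketch of the new test_util.test_tflite helper, mirroring how the model_util tests above call it. The test_util/model_util module paths, the toy model, and the /tmp file path are assumptions for illustration.

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util  # assumed module path
from mediapipe.model_maker.python.core.utils import test_util  # assumed module path

# Toy model and placeholder path, exported first so the comparison has a file to read.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=[16]),
    tf.keras.layers.Dense(2, activation='softmax'),
])
tflite_file = '/tmp/model.tflite'  # placeholder output path
model_util.export_tflite(model, tflite_file)

# Compares Keras and TFLite outputs on a random [1, 16] input tensor and
# returns True when they agree within atol.
assert test_util.test_tflite(
    keras_model=model,
    tflite_file=tflite_file,
    size=[1, 16],
    atol=1e-04)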