Add the TFLite conversion API to BlazeFaceStylizer in model maker.
PiperOrigin-RevId: 527806005
This commit is contained in:
parent
5d9761cbfd
commit
cf22c97143
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""APIs to train face stylization model."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -199,3 +200,32 @@ class FaceStylizer(object):
|
|||
tvars = self._decoder.trainable_variables
|
||||
grads = tape.gradient(style_loss, tvars)
|
||||
optimizer.apply_gradients(list(zip(grads, tvars)))
|
||||
|
||||
# TODO: Add a metadata writer for face sytlizer model.
|
||||
def export_model(self, model_name: str = 'model.tflite'):
|
||||
"""Converts and saves the model to a TFLite file with metadata included.
|
||||
|
||||
Note that only the TFLite file is needed for deployment. This function
|
||||
also saves a metadata.json file to the same directory as the TFLite file
|
||||
which can be used to interpret the metadata content in the TFLite file.
|
||||
|
||||
Args:
|
||||
model_name: File name to save TFLite model with metadata. The full export
|
||||
path is {self._hparams.export_dir}/{model_name}.
|
||||
"""
|
||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
|
||||
# Create an end-to-end model by concatenating encoder and decoder
|
||||
inputs = tf.keras.Input(shape=(256, 256, 3))
|
||||
x = self._encoder(inputs)
|
||||
x = self._decoder({'inputs': x + self.w_avg})
|
||||
outputs = x['image'][-1]
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
model=model,
|
||||
preprocess=self._preprocessor,
|
||||
)
|
||||
model_util.save_tflite(tflite_model, tflite_file)
|
||||
|
|
|
@ -50,6 +50,17 @@ class FaceStylizerTest(tf.test.TestCase):
|
|||
)
|
||||
self._evaluate_saved_model(model)
|
||||
|
||||
def test_export_face_stylizer_tflite_model(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
face_stylizer_options = face_stylizer.FaceStylizerOptions(
|
||||
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
|
||||
hparams=face_stylizer.HParams(epochs=0),
|
||||
)
|
||||
model = face_stylizer.FaceStylizer.create(
|
||||
train_data=self._train_data, options=face_stylizer_options
|
||||
)
|
||||
model.export_model()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in New Issue
Block a user