Add the TFLite conversion API to BlazeFaceStylizer in model maker.
PiperOrigin-RevId: 527806005
This commit is contained in:
parent
5d9761cbfd
commit
cf22c97143
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""APIs to train face stylization model."""
|
"""APIs to train face stylization model."""
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -199,3 +200,32 @@ class FaceStylizer(object):
|
||||||
tvars = self._decoder.trainable_variables
|
tvars = self._decoder.trainable_variables
|
||||||
grads = tape.gradient(style_loss, tvars)
|
grads = tape.gradient(style_loss, tvars)
|
||||||
optimizer.apply_gradients(list(zip(grads, tvars)))
|
optimizer.apply_gradients(list(zip(grads, tvars)))
|
||||||
|
|
||||||
|
# TODO: Add a metadata writer for face sytlizer model.
|
||||||
|
def export_model(self, model_name: str = 'model.tflite'):
|
||||||
|
"""Converts and saves the model to a TFLite file with metadata included.
|
||||||
|
|
||||||
|
Note that only the TFLite file is needed for deployment. This function
|
||||||
|
also saves a metadata.json file to the same directory as the TFLite file
|
||||||
|
which can be used to interpret the metadata content in the TFLite file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: File name to save TFLite model with metadata. The full export
|
||||||
|
path is {self._hparams.export_dir}/{model_name}.
|
||||||
|
"""
|
||||||
|
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||||
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
|
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
|
||||||
|
# Create an end-to-end model by concatenating encoder and decoder
|
||||||
|
inputs = tf.keras.Input(shape=(256, 256, 3))
|
||||||
|
x = self._encoder(inputs)
|
||||||
|
x = self._decoder({'inputs': x + self.w_avg})
|
||||||
|
outputs = x['image'][-1]
|
||||||
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
|
tflite_model = model_util.convert_to_tflite(
|
||||||
|
model=model,
|
||||||
|
preprocess=self._preprocessor,
|
||||||
|
)
|
||||||
|
model_util.save_tflite(tflite_model, tflite_file)
|
||||||
|
|
|
@ -50,6 +50,17 @@ class FaceStylizerTest(tf.test.TestCase):
|
||||||
)
|
)
|
||||||
self._evaluate_saved_model(model)
|
self._evaluate_saved_model(model)
|
||||||
|
|
||||||
|
def test_export_face_stylizer_tflite_model(self):
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
face_stylizer_options = face_stylizer.FaceStylizerOptions(
|
||||||
|
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
|
||||||
|
hparams=face_stylizer.HParams(epochs=0),
|
||||||
|
)
|
||||||
|
model = face_stylizer.FaceStylizer.create(
|
||||||
|
train_data=self._train_data, options=face_stylizer_options
|
||||||
|
)
|
||||||
|
model.export_model()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user