Add the TFLite conversion API to BlazeFaceStylizer in model maker.

PiperOrigin-RevId: 527806005
This commit is contained in:
MediaPipe Team 2023-04-28 00:28:31 -07:00 committed by Copybara-Service
parent 5d9761cbfd
commit cf22c97143
2 changed files with 41 additions and 0 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""APIs to train face stylization model."""
import os
from typing import Callable, Optional
import numpy as np
@ -199,3 +200,32 @@ class FaceStylizer(object):
tvars = self._decoder.trainable_variables
grads = tape.gradient(style_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
# TODO: Add a metadata writer for face sytlizer model.
def export_model(self, model_name: str = 'model.tflite'):
"""Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function
also saves a metadata.json file to the same directory as the TFLite file
which can be used to interpret the metadata content in the TFLite file.
Args:
model_name: File name to save TFLite model with metadata. The full export
path is {self._hparams.export_dir}/{model_name}.
"""
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name)
# Create an end-to-end model by concatenating encoder and decoder
inputs = tf.keras.Input(shape=(256, 256, 3))
x = self._encoder(inputs)
x = self._decoder({'inputs': x + self.w_avg})
outputs = x['image'][-1]
model = tf.keras.Model(inputs=inputs, outputs=outputs)
tflite_model = model_util.convert_to_tflite(
model=model,
preprocess=self._preprocessor,
)
model_util.save_tflite(tflite_model, tflite_file)

View File

@ -50,6 +50,17 @@ class FaceStylizerTest(tf.test.TestCase):
)
self._evaluate_saved_model(model)
def test_export_face_stylizer_tflite_model(self):
with self.test_session(use_gpu=True):
face_stylizer_options = face_stylizer.FaceStylizerOptions(
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
hparams=face_stylizer.HParams(epochs=0),
)
model = face_stylizer.FaceStylizer.create(
train_data=self._train_data, options=face_stylizer_options
)
model.export_model()
if __name__ == '__main__':
tf.test.main()