Add the TFLite conversion API to BlazeFaceStylizer in model maker.

PiperOrigin-RevId: 527806005
This commit is contained in:
MediaPipe Team 2023-04-28 00:28:31 -07:00 committed by Copybara-Service
parent 5d9761cbfd
commit cf22c97143
2 changed files with 41 additions and 0 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""APIs to train face stylization model.""" """APIs to train face stylization model."""
import os
from typing import Callable, Optional from typing import Callable, Optional
import numpy as np import numpy as np
@ -199,3 +200,32 @@ class FaceStylizer(object):
tvars = self._decoder.trainable_variables tvars = self._decoder.trainable_variables
grads = tape.gradient(style_loss, tvars) grads = tape.gradient(style_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars))) optimizer.apply_gradients(list(zip(grads, tvars)))
# TODO: Add a metadata writer for face sytlizer model.
def export_model(self, model_name: str = 'model.tflite'):
"""Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function
also saves a metadata.json file to the same directory as the TFLite file
which can be used to interpret the metadata content in the TFLite file.
Args:
model_name: File name to save TFLite model with metadata. The full export
path is {self._hparams.export_dir}/{model_name}.
"""
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name)
# Create an end-to-end model by concatenating encoder and decoder
inputs = tf.keras.Input(shape=(256, 256, 3))
x = self._encoder(inputs)
x = self._decoder({'inputs': x + self.w_avg})
outputs = x['image'][-1]
model = tf.keras.Model(inputs=inputs, outputs=outputs)
tflite_model = model_util.convert_to_tflite(
model=model,
preprocess=self._preprocessor,
)
model_util.save_tflite(tflite_model, tflite_file)

View File

@ -50,6 +50,17 @@ class FaceStylizerTest(tf.test.TestCase):
) )
self._evaluate_saved_model(model) self._evaluate_saved_model(model)
def test_export_face_stylizer_tflite_model(self):
with self.test_session(use_gpu=True):
face_stylizer_options = face_stylizer.FaceStylizerOptions(
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
hparams=face_stylizer.HParams(epochs=0),
)
model = face_stylizer.FaceStylizer.create(
train_data=self._train_data, options=face_stylizer_options
)
model.export_model()
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()