add metadata writer into face stylizer.

PiperOrigin-RevId: 555596257
This commit is contained in:
Yuqi Li 2023-08-10 12:07:40 -07:00 committed by Copybara-Service
parent 91f15d8e4a
commit c448d54aa7
3 changed files with 81 additions and 14 deletions

View File

@ -86,6 +86,7 @@ py_library(
"//mediapipe/model_maker/python/core/utils:loss_functions", "//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/vision/core:image_preprocessing", "//mediapipe/model_maker/python/vision/core:image_preprocessing",
"//mediapipe/tasks/python/metadata/metadata_writers:face_stylizer",
], ],
) )

View File

@ -15,6 +15,7 @@
import os import os
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
import zipfile
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -28,6 +29,16 @@ from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_opti
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer as metadata_writer
# Face detector model and face landmarks detector file names.
_FACE_DETECTOR_MODEL = 'face_detector.tflite'
_FACE_LANDMARKS_DETECTOR_MODEL = 'face_landmarks_detector.tflite'
# The mean value used in the input tensor normalization for the face stylizer
# model.
_NORM_MEAN = 0.0
_NORM_STD = 255.0
class FaceStylizer(object): class FaceStylizer(object):
@ -197,21 +208,26 @@ class FaceStylizer(object):
grads = tape.gradient(style_loss, tvars) grads = tape.gradient(style_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars))) optimizer.apply_gradients(list(zip(grads, tvars)))
# TODO: Add a metadata writer for face sytlizer model. def export_model(self, model_name: str = 'face_stylizer.task'):
def export_model(self, model_name: str = 'model.tflite'): """Converts the model to TFLite and exports as a model bundle file.
"""Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function Saves a model bundle file and metadata json file to hparams.export_dir. The
also saves a metadata.json file to the same directory as the TFLite file resulting model bundle file will contain necessary models for face
which can be used to interpret the metadata content in the TFLite file. detection, face landmarks detection, and customized face stylization. Only
the model bundle file is needed for the downstream face stylization task.
The metadata.json file is saved only to interpret the contents of the model
bundle file. The face detection model and face landmarks detection model are
from https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task
and the customized face stylization model is trained in this library.
Args: Args:
model_name: File name to save TFLite model with metadata. The full export model_name: Face stylizer model bundle file name. The full export path is
path is {self._hparams.export_dir}/{model_name}. {self._hparams.export_dir}/{model_name}.
""" """
if not tf.io.gfile.exists(self._hparams.export_dir): if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir) tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name) model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
# Create an end-to-end model by concatenating encoder and decoder # Create an end-to-end model by concatenating encoder and decoder
inputs = tf.keras.Input(shape=(256, 256, 3)) inputs = tf.keras.Input(shape=(256, 256, 3))
@ -223,8 +239,44 @@ class FaceStylizer(object):
outputs = (x + 1.0) / 2.0 outputs = (x + 1.0) / 2.0
model = tf.keras.Model(inputs=inputs, outputs=outputs) model = tf.keras.Model(inputs=inputs, outputs=outputs)
tflite_model = model_util.convert_to_tflite( face_stylizer_model_buffer = model_util.convert_to_tflite(
model=model, model=model,
preprocess=self._preprocessor, preprocess=self._preprocessor,
) )
model_util.save_tflite(tflite_model, tflite_file)
face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path()
with zipfile.ZipFile(face_aligner_task_file_path, 'r') as zf:
file_list = zf.namelist()
if _FACE_DETECTOR_MODEL not in file_list:
raise ValueError(
'{0} is not packed in face aligner task file'.format(
_FACE_DETECTOR_MODEL
)
)
if _FACE_LANDMARKS_DETECTOR_MODEL not in file_list:
raise ValueError(
'{0} is not packed in face aligner task file'.format(
_FACE_LANDMARKS_DETECTOR_MODEL
)
)
with zf.open(_FACE_DETECTOR_MODEL) as f:
face_detector_model_buffer = f.read()
with zf.open(_FACE_LANDMARKS_DETECTOR_MODEL) as f:
face_landmarks_detector_model_buffer = f.read()
writer = metadata_writer.MetadataWriter.create(
bytearray(face_stylizer_model_buffer),
bytearray(face_detector_model_buffer),
bytearray(face_landmarks_detector_model_buffer),
input_norm_mean=[_NORM_MEAN],
input_norm_std=[_NORM_STD],
)
model_bundle_content, metadata_json = writer.populate()
with open(model_bundle_file, 'wb') as f:
f.write(model_bundle_content)
with open(metadata_file, 'w') as f:
f.write(metadata_json)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import zipfile
import tensorflow as tf import tensorflow as tf
@ -65,10 +66,23 @@ class FaceStylizerTest(tf.test.TestCase):
model = face_stylizer.FaceStylizer.create( model = face_stylizer.FaceStylizer.create(
train_data=self._train_data, options=face_stylizer_options train_data=self._train_data, options=face_stylizer_options
) )
tflite_model_name = 'custom_face_stylizer.tflite' model.export_model()
model.export_model(model_name=tflite_model_name) model_bundle_file = os.path.join(
self.get_temp_dir(), 'face_stylizer.task'
)
with zipfile.ZipFile(model_bundle_file) as zf:
self.assertEqual(
set(zf.namelist()),
set([
'face_detector.tflite',
'face_landmarks_detector.tflite',
'face_stylizer.tflite',
]),
)
zf.extractall(self.get_temp_dir())
face_stylizer_tflite_file = os.path.join( face_stylizer_tflite_file = os.path.join(
self.get_temp_dir(), tflite_model_name self.get_temp_dir(), 'face_stylizer.tflite'
) )
spec = face_stylizer.SupportedModels.get(model_enum) spec = face_stylizer.SupportedModels.get(model_enum)
input_image_shape = spec.input_image_shape input_image_shape = spec.input_image_shape