add metadata writer into face stylizer.

PiperOrigin-RevId: 555596257
This commit is contained in:
Yuqi Li 2023-08-10 12:07:40 -07:00 committed by Copybara-Service
parent 91f15d8e4a
commit c448d54aa7
3 changed files with 81 additions and 14 deletions

View File

@ -86,6 +86,7 @@ py_library(
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
"//mediapipe/tasks/python/metadata/metadata_writers:face_stylizer",
],
)

View File

@ -15,6 +15,7 @@
import os
from typing import Any, Callable, Optional
import zipfile
import numpy as np
import tensorflow as tf
@ -28,6 +29,16 @@ from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_opti
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer as metadata_writer
# Face detector model and face landmarks detector file names.
_FACE_DETECTOR_MODEL = 'face_detector.tflite'
_FACE_LANDMARKS_DETECTOR_MODEL = 'face_landmarks_detector.tflite'
# The mean value used in the input tensor normalization for the face stylizer
# model.
_NORM_MEAN = 0.0
_NORM_STD = 255.0
class FaceStylizer(object):
@ -197,21 +208,26 @@ class FaceStylizer(object):
grads = tape.gradient(style_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
# TODO: Add a metadata writer for face sytlizer model.
def export_model(self, model_name: str = 'model.tflite'):
"""Converts and saves the model to a TFLite file with metadata included.
def export_model(self, model_name: str = 'face_stylizer.task'):
"""Converts the model to TFLite and exports as a model bundle file.
Note that only the TFLite file is needed for deployment. This function
also saves a metadata.json file to the same directory as the TFLite file
which can be used to interpret the metadata content in the TFLite file.
Saves a model bundle file and metadata json file to hparams.export_dir. The
resulting model bundle file will contain necessary models for face
detection, face landmarks detection, and customized face stylization. Only
the model bundle file is needed for the downstream face stylization task.
The metadata.json file is saved only to interpret the contents of the model
bundle file. The face detection model and face landmarks detection model are
from https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task
and the customized face stylization model is trained in this library.
Args:
model_name: File name to save TFLite model with metadata. The full export
path is {self._hparams.export_dir}/{model_name}.
model_name: Face stylizer model bundle file name. The full export path is
{self._hparams.export_dir}/{model_name}.
"""
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name)
model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
# Create an end-to-end model by concatenating encoder and decoder
inputs = tf.keras.Input(shape=(256, 256, 3))
@ -223,8 +239,44 @@ class FaceStylizer(object):
outputs = (x + 1.0) / 2.0
model = tf.keras.Model(inputs=inputs, outputs=outputs)
tflite_model = model_util.convert_to_tflite(
face_stylizer_model_buffer = model_util.convert_to_tflite(
model=model,
preprocess=self._preprocessor,
)
model_util.save_tflite(tflite_model, tflite_file)
face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path()
with zipfile.ZipFile(face_aligner_task_file_path, 'r') as zf:
file_list = zf.namelist()
if _FACE_DETECTOR_MODEL not in file_list:
raise ValueError(
'{0} is not packed in face aligner task file'.format(
_FACE_DETECTOR_MODEL
)
)
if _FACE_LANDMARKS_DETECTOR_MODEL not in file_list:
raise ValueError(
'{0} is not packed in face aligner task file'.format(
_FACE_LANDMARKS_DETECTOR_MODEL
)
)
with zf.open(_FACE_DETECTOR_MODEL) as f:
face_detector_model_buffer = f.read()
with zf.open(_FACE_LANDMARKS_DETECTOR_MODEL) as f:
face_landmarks_detector_model_buffer = f.read()
writer = metadata_writer.MetadataWriter.create(
bytearray(face_stylizer_model_buffer),
bytearray(face_detector_model_buffer),
bytearray(face_landmarks_detector_model_buffer),
input_norm_mean=[_NORM_MEAN],
input_norm_std=[_NORM_STD],
)
model_bundle_content, metadata_json = writer.populate()
with open(model_bundle_file, 'wb') as f:
f.write(model_bundle_content)
with open(metadata_file, 'w') as f:
f.write(metadata_json)

View File

@ -13,6 +13,7 @@
# limitations under the License.
import os
import zipfile
import tensorflow as tf
@ -65,10 +66,23 @@ class FaceStylizerTest(tf.test.TestCase):
model = face_stylizer.FaceStylizer.create(
train_data=self._train_data, options=face_stylizer_options
)
tflite_model_name = 'custom_face_stylizer.tflite'
model.export_model(model_name=tflite_model_name)
model.export_model()
model_bundle_file = os.path.join(
self.get_temp_dir(), 'face_stylizer.task'
)
with zipfile.ZipFile(model_bundle_file) as zf:
self.assertEqual(
set(zf.namelist()),
set([
'face_detector.tflite',
'face_landmarks_detector.tflite',
'face_stylizer.tflite',
]),
)
zf.extractall(self.get_temp_dir())
face_stylizer_tflite_file = os.path.join(
self.get_temp_dir(), tflite_model_name
self.get_temp_dir(), 'face_stylizer.tflite'
)
spec = face_stylizer.SupportedModels.get(model_enum)
input_image_shape = spec.input_image_shape