add metadata writer into face stylizer.
PiperOrigin-RevId: 555596257
This commit is contained in:
parent
91f15d8e4a
commit
c448d54aa7
|
@ -86,6 +86,7 @@ py_library(
|
|||
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:face_stylizer",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
import zipfile
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -28,6 +29,16 @@ from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_opti
|
|||
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer as metadata_writer
|
||||
|
||||
# Face detector model and face landmarks detector file names.
|
||||
_FACE_DETECTOR_MODEL = 'face_detector.tflite'
|
||||
_FACE_LANDMARKS_DETECTOR_MODEL = 'face_landmarks_detector.tflite'
|
||||
|
||||
# The mean value used in the input tensor normalization for the face stylizer
|
||||
# model.
|
||||
_NORM_MEAN = 0.0
|
||||
_NORM_STD = 255.0
|
||||
|
||||
|
||||
class FaceStylizer(object):
|
||||
|
@ -197,21 +208,26 @@ class FaceStylizer(object):
|
|||
grads = tape.gradient(style_loss, tvars)
|
||||
optimizer.apply_gradients(list(zip(grads, tvars)))
|
||||
|
||||
# TODO: Add a metadata writer for face sytlizer model.
|
||||
def export_model(self, model_name: str = 'model.tflite'):
|
||||
"""Converts and saves the model to a TFLite file with metadata included.
|
||||
def export_model(self, model_name: str = 'face_stylizer.task'):
|
||||
"""Converts the model to TFLite and exports as a model bundle file.
|
||||
|
||||
Note that only the TFLite file is needed for deployment. This function
|
||||
also saves a metadata.json file to the same directory as the TFLite file
|
||||
which can be used to interpret the metadata content in the TFLite file.
|
||||
Saves a model bundle file and metadata json file to hparams.export_dir. The
|
||||
resulting model bundle file will contain necessary models for face
|
||||
detection, face landmarks detection, and customized face stylization. Only
|
||||
the model bundle file is needed for the downstream face stylization task.
|
||||
The metadata.json file is saved only to interpret the contents of the model
|
||||
bundle file. The face detection model and face landmarks detection model are
|
||||
from https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task
|
||||
and the customized face stylization model is trained in this library.
|
||||
|
||||
Args:
|
||||
model_name: File name to save TFLite model with metadata. The full export
|
||||
path is {self._hparams.export_dir}/{model_name}.
|
||||
model_name: Face stylizer model bundle file name. The full export path is
|
||||
{self._hparams.export_dir}/{model_name}.
|
||||
"""
|
||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||
|
||||
# Create an end-to-end model by concatenating encoder and decoder
|
||||
inputs = tf.keras.Input(shape=(256, 256, 3))
|
||||
|
@ -223,8 +239,44 @@ class FaceStylizer(object):
|
|||
outputs = (x + 1.0) / 2.0
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
face_stylizer_model_buffer = model_util.convert_to_tflite(
|
||||
model=model,
|
||||
preprocess=self._preprocessor,
|
||||
)
|
||||
model_util.save_tflite(tflite_model, tflite_file)
|
||||
|
||||
face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path()
|
||||
|
||||
with zipfile.ZipFile(face_aligner_task_file_path, 'r') as zf:
|
||||
file_list = zf.namelist()
|
||||
if _FACE_DETECTOR_MODEL not in file_list:
|
||||
raise ValueError(
|
||||
'{0} is not packed in face aligner task file'.format(
|
||||
_FACE_DETECTOR_MODEL
|
||||
)
|
||||
)
|
||||
if _FACE_LANDMARKS_DETECTOR_MODEL not in file_list:
|
||||
raise ValueError(
|
||||
'{0} is not packed in face aligner task file'.format(
|
||||
_FACE_LANDMARKS_DETECTOR_MODEL
|
||||
)
|
||||
)
|
||||
|
||||
with zf.open(_FACE_DETECTOR_MODEL) as f:
|
||||
face_detector_model_buffer = f.read()
|
||||
|
||||
with zf.open(_FACE_LANDMARKS_DETECTOR_MODEL) as f:
|
||||
face_landmarks_detector_model_buffer = f.read()
|
||||
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
bytearray(face_stylizer_model_buffer),
|
||||
bytearray(face_detector_model_buffer),
|
||||
bytearray(face_landmarks_detector_model_buffer),
|
||||
input_norm_mean=[_NORM_MEAN],
|
||||
input_norm_std=[_NORM_STD],
|
||||
)
|
||||
|
||||
model_bundle_content, metadata_json = writer.populate()
|
||||
with open(model_bundle_file, 'wb') as f:
|
||||
f.write(model_bundle_content)
|
||||
with open(metadata_file, 'w') as f:
|
||||
f.write(metadata_json)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -65,10 +66,23 @@ class FaceStylizerTest(tf.test.TestCase):
|
|||
model = face_stylizer.FaceStylizer.create(
|
||||
train_data=self._train_data, options=face_stylizer_options
|
||||
)
|
||||
tflite_model_name = 'custom_face_stylizer.tflite'
|
||||
model.export_model(model_name=tflite_model_name)
|
||||
model.export_model()
|
||||
model_bundle_file = os.path.join(
|
||||
self.get_temp_dir(), 'face_stylizer.task'
|
||||
)
|
||||
with zipfile.ZipFile(model_bundle_file) as zf:
|
||||
self.assertEqual(
|
||||
set(zf.namelist()),
|
||||
set([
|
||||
'face_detector.tflite',
|
||||
'face_landmarks_detector.tflite',
|
||||
'face_stylizer.tflite',
|
||||
]),
|
||||
)
|
||||
zf.extractall(self.get_temp_dir())
|
||||
|
||||
face_stylizer_tflite_file = os.path.join(
|
||||
self.get_temp_dir(), tflite_model_name
|
||||
self.get_temp_dir(), 'face_stylizer.tflite'
|
||||
)
|
||||
spec = face_stylizer.SupportedModels.get(model_enum)
|
||||
input_image_shape = spec.input_image_shape
|
||||
|
|
Loading…
Reference in New Issue
Block a user