add metadata writer into face stylizer.
PiperOrigin-RevId: 555596257
This commit is contained in:
parent
91f15d8e4a
commit
c448d54aa7
|
@ -86,6 +86,7 @@ py_library(
|
||||||
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
|
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:face_stylizer",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
import zipfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -28,6 +29,16 @@ from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_opti
|
||||||
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp
|
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
|
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
|
||||||
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
|
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer as metadata_writer
|
||||||
|
|
||||||
|
# Face detector model and face landmarks detector file names.
|
||||||
|
_FACE_DETECTOR_MODEL = 'face_detector.tflite'
|
||||||
|
_FACE_LANDMARKS_DETECTOR_MODEL = 'face_landmarks_detector.tflite'
|
||||||
|
|
||||||
|
# The mean value used in the input tensor normalization for the face stylizer
|
||||||
|
# model.
|
||||||
|
_NORM_MEAN = 0.0
|
||||||
|
_NORM_STD = 255.0
|
||||||
|
|
||||||
|
|
||||||
class FaceStylizer(object):
|
class FaceStylizer(object):
|
||||||
|
@ -197,21 +208,26 @@ class FaceStylizer(object):
|
||||||
grads = tape.gradient(style_loss, tvars)
|
grads = tape.gradient(style_loss, tvars)
|
||||||
optimizer.apply_gradients(list(zip(grads, tvars)))
|
optimizer.apply_gradients(list(zip(grads, tvars)))
|
||||||
|
|
||||||
# TODO: Add a metadata writer for face sytlizer model.
|
def export_model(self, model_name: str = 'face_stylizer.task'):
|
||||||
def export_model(self, model_name: str = 'model.tflite'):
|
"""Converts the model to TFLite and exports as a model bundle file.
|
||||||
"""Converts and saves the model to a TFLite file with metadata included.
|
|
||||||
|
|
||||||
Note that only the TFLite file is needed for deployment. This function
|
Saves a model bundle file and metadata json file to hparams.export_dir. The
|
||||||
also saves a metadata.json file to the same directory as the TFLite file
|
resulting model bundle file will contain necessary models for face
|
||||||
which can be used to interpret the metadata content in the TFLite file.
|
detection, face landmarks detection, and customized face stylization. Only
|
||||||
|
the model bundle file is needed for the downstream face stylization task.
|
||||||
|
The metadata.json file is saved only to interpret the contents of the model
|
||||||
|
bundle file. The face detection model and face landmarks detection model are
|
||||||
|
from https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task
|
||||||
|
and the customized face stylization model is trained in this library.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: File name to save TFLite model with metadata. The full export
|
model_name: Face stylizer model bundle file name. The full export path is
|
||||||
path is {self._hparams.export_dir}/{model_name}.
|
{self._hparams.export_dir}/{model_name}.
|
||||||
"""
|
"""
|
||||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||||
|
|
||||||
# Create an end-to-end model by concatenating encoder and decoder
|
# Create an end-to-end model by concatenating encoder and decoder
|
||||||
inputs = tf.keras.Input(shape=(256, 256, 3))
|
inputs = tf.keras.Input(shape=(256, 256, 3))
|
||||||
|
@ -223,8 +239,44 @@ class FaceStylizer(object):
|
||||||
outputs = (x + 1.0) / 2.0
|
outputs = (x + 1.0) / 2.0
|
||||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
tflite_model = model_util.convert_to_tflite(
|
face_stylizer_model_buffer = model_util.convert_to_tflite(
|
||||||
model=model,
|
model=model,
|
||||||
preprocess=self._preprocessor,
|
preprocess=self._preprocessor,
|
||||||
)
|
)
|
||||||
model_util.save_tflite(tflite_model, tflite_file)
|
|
||||||
|
face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path()
|
||||||
|
|
||||||
|
with zipfile.ZipFile(face_aligner_task_file_path, 'r') as zf:
|
||||||
|
file_list = zf.namelist()
|
||||||
|
if _FACE_DETECTOR_MODEL not in file_list:
|
||||||
|
raise ValueError(
|
||||||
|
'{0} is not packed in face aligner task file'.format(
|
||||||
|
_FACE_DETECTOR_MODEL
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if _FACE_LANDMARKS_DETECTOR_MODEL not in file_list:
|
||||||
|
raise ValueError(
|
||||||
|
'{0} is not packed in face aligner task file'.format(
|
||||||
|
_FACE_LANDMARKS_DETECTOR_MODEL
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with zf.open(_FACE_DETECTOR_MODEL) as f:
|
||||||
|
face_detector_model_buffer = f.read()
|
||||||
|
|
||||||
|
with zf.open(_FACE_LANDMARKS_DETECTOR_MODEL) as f:
|
||||||
|
face_landmarks_detector_model_buffer = f.read()
|
||||||
|
|
||||||
|
writer = metadata_writer.MetadataWriter.create(
|
||||||
|
bytearray(face_stylizer_model_buffer),
|
||||||
|
bytearray(face_detector_model_buffer),
|
||||||
|
bytearray(face_landmarks_detector_model_buffer),
|
||||||
|
input_norm_mean=[_NORM_MEAN],
|
||||||
|
input_norm_std=[_NORM_STD],
|
||||||
|
)
|
||||||
|
|
||||||
|
model_bundle_content, metadata_json = writer.populate()
|
||||||
|
with open(model_bundle_file, 'wb') as f:
|
||||||
|
f.write(model_bundle_content)
|
||||||
|
with open(metadata_file, 'w') as f:
|
||||||
|
f.write(metadata_json)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import zipfile
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -65,10 +66,23 @@ class FaceStylizerTest(tf.test.TestCase):
|
||||||
model = face_stylizer.FaceStylizer.create(
|
model = face_stylizer.FaceStylizer.create(
|
||||||
train_data=self._train_data, options=face_stylizer_options
|
train_data=self._train_data, options=face_stylizer_options
|
||||||
)
|
)
|
||||||
tflite_model_name = 'custom_face_stylizer.tflite'
|
model.export_model()
|
||||||
model.export_model(model_name=tflite_model_name)
|
model_bundle_file = os.path.join(
|
||||||
|
self.get_temp_dir(), 'face_stylizer.task'
|
||||||
|
)
|
||||||
|
with zipfile.ZipFile(model_bundle_file) as zf:
|
||||||
|
self.assertEqual(
|
||||||
|
set(zf.namelist()),
|
||||||
|
set([
|
||||||
|
'face_detector.tflite',
|
||||||
|
'face_landmarks_detector.tflite',
|
||||||
|
'face_stylizer.tflite',
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
zf.extractall(self.get_temp_dir())
|
||||||
|
|
||||||
face_stylizer_tflite_file = os.path.join(
|
face_stylizer_tflite_file = os.path.join(
|
||||||
self.get_temp_dir(), tflite_model_name
|
self.get_temp_dir(), 'face_stylizer.tflite'
|
||||||
)
|
)
|
||||||
spec = face_stylizer.SupportedModels.get(model_enum)
|
spec = face_stylizer.SupportedModels.get(model_enum)
|
||||||
input_image_shape = spec.input_image_shape
|
input_image_shape = spec.input_image_shape
|
||||||
|
|
Loading…
Reference in New Issue
Block a user