# mediapipe/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py
# 2023-09-02 09:56:40 -07:00
#
# 318 lines
# 11 KiB
# Python

# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train face stylization model."""
import logging
import os
from typing import Any, Callable, Optional
import zipfile
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.face_stylizer import constants
from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer as metadata_writer
# Face detector model and face landmarks detector file names.
_FACE_DETECTOR_MODEL = 'face_detector.tflite'
_FACE_LANDMARKS_DETECTOR_MODEL = 'face_landmarks_detector.tflite'
# The mean and standard deviation values used in the input tensor
# normalization for the face stylizer model.
_NORM_MEAN = 0.0
_NORM_STD = 255.0
class FaceStylizer(object):
  """FaceStylizer for building face stylization model.

  Instances are normally obtained through `FaceStylizer.create`, which loads
  the model components from files and trains the decoder on a style image.

  Attributes:
    w_avg: An average face latent code to regularize face generation in face
      stylization. It is set when the model components are created.
  """

  def __init__(
      self,
      model_spec: ms.ModelSpec,
      model_options: model_opt.FaceStylizerModelOptions,
      hparams: hp.HParams,
  ):
    """Initializes face stylizer.

    Only stores the configuration and builds the image preprocessor; the model
    components are loaded later by `_create_model`.

    Args:
      model_spec: Specification for the model.
      model_options: Model options for creating face stylizer.
      hparams: The hyperparameters for training face stylizer.
    """
    self._model_spec = model_spec
    self._model_options = model_options
    self._hparams = hparams
    self._preprocessor = image_preprocessing.Preprocessor(
        input_shape=self._model_spec.input_image_shape,
        num_classes=1,
        mean_rgb=self._model_spec.mean_rgb,
        stddev_rgb=self._model_spec.stddev_rgb,
    )

  @classmethod
  def create(
      cls,
      train_data: classification_ds.ClassificationDataset,
      options: face_stylizer_options.FaceStylizerOptions,
  ) -> 'FaceStylizer':
    """Creates and trains a face stylizer with input datasets.

    If `options.model_options` or `options.hparams` is None, it is replaced on
    the passed-in `options` object with a default-constructed instance.

    Args:
      train_data: The input style image dataset for training the face stylizer.
      options: The options to configure face stylizer.

    Returns:
      A FaceStylizer instance with the trained model.
    """
    if options.model_options is None:
      options.model_options = model_opt.FaceStylizerModelOptions()

    if options.hparams is None:
      options.hparams = hp.HParams()

    spec = ms.SupportedModels.get(options.model)

    face_stylizer = cls(
        model_spec=spec,
        model_options=options.model_options,
        hparams=options.hparams,
    )
    face_stylizer._create_and_train_model(train_data)
    return face_stylizer

  def stylize(
      self, data: classification_ds.ClassificationDataset
  ) -> classification_ds.ClassificationDataset:
    """Stylizes the images represented by the input dataset.

    Args:
      data: Dataset of input images, can contain multiple images.

    Returns:
      A dataset containing the stylized images.
    """
    input_dataset = data.gen_tf_dataset(preprocess=self._preprocessor)
    output_img_list = []
    for sample in input_dataset:
      # The first element of each sample is the image.
      image = sample[0]
      w = self._encoder(image, training=True)
      x = self._decoder({'inputs': w + self.w_avg}, training=True)
      # The last entry of the decoder's 'image' output is the one that is kept.
      output_batch = x['image'][-1]
      # Rescale with (x + 1) * 127.5, i.e. from [-1, 1] to [0, 255].
      output_img_tensor = (tf.squeeze(output_batch).numpy() + 1.0) * 127.5
      output_img_list.append(output_img_tensor)

    image_ds = tf.data.Dataset.from_tensor_slices(output_img_list)

    logging.info('Stylized %s images.', len(output_img_list))

    return classification_ds.ClassificationDataset(
        dataset=image_ds,
        label_names=['stylized'],
        size=len(output_img_list),
    )

  def _create_and_train_model(
      self, train_data: classification_ds.ClassificationDataset
  ):
    """Creates and trains the face stylizer model.

    Args:
      train_data: Training data.
    """
    self._create_model()
    self._train_model(train_data=train_data, preprocessor=self._preprocessor)

  def _create_model(self):
    """Creates the components of face stylizer.

    Loads the encoder, decoder, mapping network and discriminator Keras models
    and the average latent code from the files referenced in `constants`.
    """
    self._encoder = model_util.load_keras_model(
        constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
    )
    self._decoder = model_util.load_keras_model(
        constants.FACE_STYLIZER_DECODER_MODEL_FILES.get_path()
    )
    self._mapping_network = model_util.load_keras_model(
        constants.FACE_STYLIZER_MAPPING_MODEL_FILES.get_path()
    )
    self._discriminator = model_util.load_keras_model(
        constants.FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES.get_path()
    )
    with tf.io.gfile.GFile(
        constants.FACE_STYLIZER_W_FILES.get_path(), 'rb'
    ) as f:
      w_avg = np.load(f)

    # Keep the first `style_block_num` entries and add a leading axis.
    self.w_avg = w_avg[: self._model_spec.style_block_num][np.newaxis]

  def _train_model(
      self,
      train_data: classification_ds.ClassificationDataset,
      preprocessor: Optional[Callable[..., Any]] = None,
  ):
    """Trains the face stylizer model.

    Only the first sample yielded by the training dataset is used as the style
    image, and only the decoder's trainable variables are updated.

    Args:
      train_data: The data for training model.
      preprocessor: The image preprocessor.

    Raises:
      ValueError: If the training dataset yields no samples.
    """
    train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)

    # TODO: Support processing multiple input style images. The
    # input style images are expected to have similar style.
    # style_sample represents a tuple of (style_image, style_label).
    style_sample = next(iter(train_dataset), None)
    if style_sample is None:
      # Fail with a clear message instead of leaking a bare StopIteration.
      raise ValueError('train_data must contain at least one style image.')
    style_img = style_sample[0]

    batch_size = self._hparams.batch_size
    label_in = tf.zeros(shape=[batch_size, 0])

    style_encoding = self._encoder(style_img, training=True) + self.w_avg

    optimizer = tf.keras.optimizers.Adam(
        learning_rate=self._hparams.learning_rate,
        beta_1=self._hparams.beta_1,
        beta_2=self._hparams.beta_2,
    )

    image_perceptual_quality_loss = loss_functions.ImagePerceptualQualityLoss(
        loss_weight=self._model_options.perception_loss_weight
    )
    # Loop-invariant objects are created once rather than on every iteration.
    feature_loss = tf.keras.losses.MeanAbsoluteError()
    alpha = self._model_options.alpha

    for i in range(self._hparams.epochs):
      noise = tf.random.normal(shape=[batch_size, constants.STYLE_DIM])

      mean_w = self._mapping_network([noise, label_in], training=False)[
          :, : self._model_spec.style_block_num
      ]
      style_encodings = tf.tile(style_encoding, [batch_size, 1, 1])

      in_latent = tf.Variable(tf.identity(style_encodings))

      # For every swap layer, blend the style encoding with the latent code
      # produced by the mapping network, weighted by alpha.
      for swap_layer in self._model_options.swap_layers:
        in_latent = in_latent[:, swap_layer].assign(
            alpha * style_encodings[:, swap_layer]
            + (1 - alpha) * mean_w[:, swap_layer]
        )

      # Only the forward pass and the loss are recorded on the tape.
      with tf.GradientTape() as tape:
        outputs = self._decoder({'inputs': in_latent.numpy()}, training=True)
        gen_img = outputs['image'][-1]

        real_feature = self._discriminator(
            [tf.transpose(style_img, [0, 3, 1, 2]), label_in]
        )
        gen_feature = self._discriminator(
            [tf.transpose(gen_img, [0, 3, 1, 2]), label_in]
        )

        style_loss = image_perceptual_quality_loss(gen_img, style_img)
        style_loss += (
            feature_loss(real_feature, gen_feature)
            * self._model_options.adv_loss_weight
        )
      print(f'Iteration {i} loss: {style_loss.numpy()}')

      # Gradients are computed and applied after leaving the tape context so
      # that these operations are not themselves recorded.
      tvars = self._decoder.trainable_variables
      grads = tape.gradient(style_loss, tvars)
      optimizer.apply_gradients(list(zip(grads, tvars)))

  def export_model(self, model_name: str = 'face_stylizer.task'):
    """Converts the model to TFLite and exports as a model bundle file.

    Saves a model bundle file and metadata json file to hparams.export_dir. The
    resulting model bundle file will contain necessary models for face
    detection, face landmarks detection, and customized face stylization. Only
    the model bundle file is needed for the downstream face stylization task.
    The metadata.json file is saved only to interpret the contents of the model
    bundle file. The face detection model and face landmarks detection model are
    from https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task
    and the customized face stylization model is trained in this library.

    Args:
      model_name: Face stylizer model bundle file name. The full export path is
        {self._hparams.export_dir}/{model_name}.

    Raises:
      ValueError: If the face aligner task file does not contain the face
        detector model or the face landmarks detector model.
    """
    export_dir = self._hparams.export_dir
    if not tf.io.gfile.exists(export_dir):
      tf.io.gfile.makedirs(export_dir)
    model_bundle_file = os.path.join(export_dir, model_name)
    metadata_file = os.path.join(export_dir, 'metadata.json')

    # Create an end-to-end model by concatenating encoder and decoder
    inputs = tf.keras.Input(shape=(256, 256, 3))
    x = self._encoder(inputs, training=True)
    x = self._decoder({'inputs': x + self.w_avg}, training=True)
    x = x['image'][-1]
    # Scale the data range from [-1, 1] to [0, 1] to support running inference
    # on both CPU and GPU.
    outputs = (x + 1.0) / 2.0
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    face_stylizer_model_buffer = model_util.convert_to_tflite(
        model=model,
        quantization_config=None,
        supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,),
        preprocess=self._preprocessor,
        allow_custom_ops=True,
    )

    face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path()

    with zipfile.ZipFile(face_aligner_task_file_path, 'r') as zf:
      file_list = zf.namelist()
      # Both auxiliary models must be present in the face aligner archive.
      required_models = (_FACE_DETECTOR_MODEL, _FACE_LANDMARKS_DETECTOR_MODEL)
      for required_model in required_models:
        if required_model not in file_list:
          raise ValueError(
              '{0} is not packed in face aligner task file'.format(
                  required_model
              )
          )

      with zf.open(_FACE_DETECTOR_MODEL) as f:
        face_detector_model_buffer = f.read()

      with zf.open(_FACE_LANDMARKS_DETECTOR_MODEL) as f:
        face_landmarks_detector_model_buffer = f.read()

    writer = metadata_writer.MetadataWriter.create(
        bytearray(face_stylizer_model_buffer),
        bytearray(face_detector_model_buffer),
        bytearray(face_landmarks_detector_model_buffer),
        input_norm_mean=[_NORM_MEAN],
        input_norm_std=[_NORM_STD],
    )

    model_bundle_content, metadata_json = writer.populate()
    # Write through tf.io.gfile, the same file API used above to create the
    # export directory, so both steps accept the same kinds of paths.
    with tf.io.gfile.GFile(model_bundle_file, 'wb') as f:
      f.write(bytes(model_bundle_content))
    with tf.io.gfile.GFile(metadata_file, 'w') as f:
      f.write(metadata_json)