Add customizable face stylizer module in MediaPipe model maker

PiperOrigin-RevId: 526883862
This commit is contained in:
MediaPipe Team 2023-04-25 00:45:26 -07:00 committed by Copybara-Service
parent a0eb1b696c
commit 56df724c36
6 changed files with 369 additions and 6 deletions

View File

@ -14,6 +14,7 @@
# Placeholder for internal Python strict test compatibility macro. # Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python GPU test rule.
licenses(["notice"]) licenses(["notice"])
@ -26,6 +27,12 @@ filegroup(
]), ]),
) )
py_library(
name = "constants",
srcs = ["constants.py"],
deps = ["//mediapipe/model_maker/python/core/utils:file_util"],
)
py_library( py_library(
name = "hyperparameters", name = "hyperparameters",
srcs = ["hyperparameters.py"], srcs = ["hyperparameters.py"],
@ -37,6 +44,7 @@ py_library(
py_library( py_library(
name = "model_options", name = "model_options",
srcs = ["model_options.py"], srcs = ["model_options.py"],
deps = ["//mediapipe/model_maker/python/core/utils:loss_functions"],
) )
py_library( py_library(
@ -72,11 +80,39 @@ py_library(
py_test( py_test(
name = "dataset_test", name = "dataset_test",
srcs = ["dataset_test.py"], srcs = ["dataset_test.py"],
data = [ data = [":testdata"],
":testdata",
],
deps = [ deps = [
":dataset", ":dataset",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )
py_library(
name = "face_stylizer",
srcs = ["face_stylizer.py"],
deps = [
":constants",
":face_stylizer_options",
":hyperparameters",
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
],
)
py_library(
name = "face_stylizer_import",
srcs = ["__init__.py"],
visibility = ["//visibility:public"],
deps = [
":dataset",
":face_stylizer",
":face_stylizer_options",
":hyperparameters",
":model_options",
":model_spec",
],
)

View File

@ -12,3 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""MediaPipe Model Maker Python Public API For Face Stylization.""" """MediaPipe Model Maker Python Public API For Face Stylization."""
from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer
from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters
from mediapipe.model_maker.python.vision.face_stylizer import model_options
from mediapipe.model_maker.python.vision.face_stylizer import model_spec
FaceStylizer = face_stylizer.FaceStylizer
SupportedModels = model_spec.SupportedModels
ModelOptions = model_options.FaceStylizerModelOptions
HParams = hyperparameters.HParams
Dataset = dataset.Dataset
FaceStylizerOptions = face_stylizer_options.FaceStylizerOptions

View File

@ -0,0 +1,45 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Face stylizer model constants."""
from mediapipe.model_maker.python.core.utils import file_util
# TODO: Move model files to GCS for downloading.
FACE_STYLIZER_ENCODER_MODEL_FILES = file_util.DownloadedFiles(
'face_stylizer/encoder',
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_encoder.tar.gz',
is_folder=True,
)
FACE_STYLIZER_DECODER_MODEL_FILES = file_util.DownloadedFiles(
'face_stylizer/decoder',
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_decoder.tar.gz',
is_folder=True,
)
FACE_STYLIZER_MAPPING_MODEL_FILES = file_util.DownloadedFiles(
'face_stylizer/mapping',
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_mapping.tar.gz',
is_folder=True,
)
FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES = file_util.DownloadedFiles(
'face_stylizer/discriminator',
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_discriminator.tar.gz',
is_folder=True,
)
FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
'face_stylizer/w_avg.npy',
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
)
# Dimension of the input style vector to the decoder
STYLE_DIM = 512

View File

@ -0,0 +1,201 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train face stylization model."""
from typing import Callable, Optional
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.face_stylizer import constants
from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
class FaceStylizer(object):
"""FaceStylizer for building face stylization model.
Attributes:
w_avg: An average face latent code to regularize face generation in face
stylization.
"""
def __init__(
self,
model_spec: ms.ModelSpec,
model_options: model_opt.FaceStylizerModelOptions,
hparams: hp.HParams,
):
"""Initializes face stylizer.
Args:
model_spec: Specification for the model.
model_options: Model options for creating face stylizer.
hparams: The hyperparameters for training face stylizer.
"""
self._model_spec = model_spec
self._model_options = model_options
self._hparams = hparams
# TODO: Support face alignment in image preprocessor.
self._preprocessor = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape,
num_classes=1,
mean_rgb=self._model_spec.mean_rgb,
stddev_rgb=self._model_spec.stddev_rgb,
)
@classmethod
def create(
cls,
train_data: classification_ds.ClassificationDataset,
options: face_stylizer_options.FaceStylizerOptions,
) -> 'FaceStylizer':
"""Creates and trains a face stylizer with input datasets.
Args:
train_data: The input style image dataset for training the face stylizer.
options: The options to configure face stylizer.
Returns:
A FaceStylizer instant with the trained model.
"""
if options.model_options is None:
options.model_options = model_opt.FaceStylizerModelOptions()
if options.hparams is None:
options.hparams = hp.HParams()
spec = ms.SupportedModels.get(options.model)
face_stylizer = cls(
model_spec=spec,
model_options=options.model_options,
hparams=options.hparams,
)
face_stylizer._create_and_train_model(train_data)
return face_stylizer
def _create_and_train_model(
self, train_data: classification_ds.ClassificationDataset
):
"""Creates and trains the face stylizer model.
Args:
train_data: Training data.
"""
self._create_model()
self._train_model(train_data=train_data, preprocessor=self._preprocessor)
def _create_model(self):
"""Creates the componenets of face stylizer."""
self._encoder = model_util.load_keras_model(
constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
)
self._decoder = model_util.load_keras_model(
constants.FACE_STYLIZER_DECODER_MODEL_FILES.get_path()
)
self._mapping_network = model_util.load_keras_model(
constants.FACE_STYLIZER_MAPPING_MODEL_FILES.get_path()
)
self._discriminator = model_util.load_keras_model(
constants.FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES.get_path()
)
with tf.io.gfile.GFile(
constants.FACE_STYLIZER_W_FILES.get_path(), 'rb'
) as f:
w_avg = np.load(f)
self.w_avg = w_avg[: self._model_spec.style_block_num][np.newaxis]
def _train_model(
self,
train_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None,
):
"""Trains the face stylizer model.
Args:
train_data: The data for training model.
preprocessor: The image preprocessor.
"""
train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
# TODO: Support processing mulitple input style images. The
# input style images are expected to have similar style.
# style_sample represents a tuple of (style_image, style_label).
style_sample = next(iter(train_dataset))
style_img = style_sample[0]
batch_size = self._hparams.batch_size
label_in = tf.zeros(shape=[batch_size, 0])
style_encoding = self._encoder(style_img)
optimizer = tf.keras.optimizers.Adam(
learning_rate=self._hparams.learning_rate,
beta_1=self._hparams.beta_1,
beta_2=self._hparams.beta_2,
)
image_perceptual_quality_loss = loss_functions.ImagePerceptualQualityLoss(
loss_weight=self._model_options.perception_loss_weight
)
for i in range(self._hparams.epochs):
noise = tf.random.normal(shape=[batch_size, constants.STYLE_DIM])
mean_w = self._mapping_network([noise, label_in], training=False)[
:, : self._model_spec.style_block_num
]
style_encodings = tf.tile(style_encoding, [batch_size, 1, 1])
in_latent = tf.Variable(tf.identity(style_encodings))
alpha = self._model_options.alpha
for swap_layer in self._model_options.swap_layers:
in_latent = in_latent[:, swap_layer].assign(
alpha * style_encodings[:, swap_layer]
+ (1 - alpha) * mean_w[:, swap_layer]
)
with tf.GradientTape() as tape:
outputs = self._decoder(
{'inputs': in_latent + self.w_avg},
training=False,
)
gen_img = outputs['image'][-1]
real_feature = self._discriminator(
[tf.transpose(style_img, [0, 3, 1, 2]), label_in]
)
gen_feature = self._discriminator(
[tf.transpose(gen_img, [0, 3, 1, 2]), label_in]
)
style_loss = image_perceptual_quality_loss(gen_img, style_img)
style_loss += (
tf.keras.losses.MeanAbsoluteError()(real_feature, gen_feature)
* self._model_options.adv_loss_weight
)
tf.compat.v1.logging.info(f'Iteration {i} loss: {style_loss.numpy()}')
tvars = self._decoder.trainable_variables
grads = tape.gradient(style_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))

View File

@ -0,0 +1,55 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.vision import face_stylizer
from mediapipe.tasks.python.test import test_utils
class FaceStylizerTest(tf.test.TestCase):
def _load_data(self):
"""Loads training dataset."""
input_data_dir = test_utils.get_test_data_path('testdata')
data = face_stylizer.Dataset.from_folder(dirname=input_data_dir)
return data
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
"""Evaluates the fine-tuned face stylizer model."""
test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32)
test_image_batch = test_image[tf.newaxis]
in_latent = model._encoder(test_image_batch)
output = model._decoder({'inputs': in_latent + model.w_avg})
self.assertEqual(output['image'][-1].shape, (1, 256, 256, 3))
def setUp(self):
super().setUp()
self._train_data = self._load_data()
def test_finetuning_face_stylizer_with_single_input_style_image(self):
with self.test_session(use_gpu=True):
face_stylizer_options = face_stylizer.FaceStylizerOptions(
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
hparams=face_stylizer.HParams(epochs=1),
)
model = face_stylizer.FaceStylizer.create(
train_data=self._train_data, options=face_stylizer_options
)
self._evaluate_saved_model(model)
if __name__ == '__main__':
tf.test.main()

View File

@ -13,8 +13,15 @@
# limitations under the License. # limitations under the License.
"""Configurable model options for face stylizer models.""" """Configurable model options for face stylizer models."""
from typing import Sequence
import dataclasses import dataclasses
from typing import List
from mediapipe.model_maker.python.core.utils import loss_functions
def _default_perceptual_quality_loss_weight():
"""Default perceptual quality loss weight for face stylizer."""
return loss_functions.PerceptualLossWeight(l1=2.0, content=20.0, style=10.0)
# TODO: Add more detailed instructions about hyperparameter tuning. # TODO: Add more detailed instructions about hyperparameter tuning.
@ -26,12 +33,17 @@ class FaceStylizerModelOptions:
swap_layers: The layers of feature to be interpolated between encoding swap_layers: The layers of feature to be interpolated between encoding
features and StyleGAN input features. features and StyleGAN input features.
alpha: Weighting coefficient for swapping layer interpolation. alpha: Weighting coefficient for swapping layer interpolation.
adv_loss_weight: Weighting coeffcieint of adversarial loss versus perceptual perception_loss_weight: Weighting coefficients of image perception quality
loss. loss.
adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
perceptual quality loss.
""" """
swap_layers: List[int] = dataclasses.field( swap_layers: Sequence[int] = dataclasses.field(
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11] default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
) )
alpha: float = 1.0 alpha: float = 1.0
perception_loss_weight: loss_functions.PerceptualLossWeight = (
dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
)
adv_loss_weight: float = 1.0 adv_loss_weight: float = 1.0