Add customizable face stylizer module in MediaPipe model maker
PiperOrigin-RevId: 526883862
This commit is contained in:
parent
a0eb1b696c
commit
56df724c36
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
# Placeholder for internal Python strict test compatibility macro.
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
# Placeholder for internal Python strict library and test compatibility macro.
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
# Placeholder for internal Python GPU test rule.
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
@ -26,6 +27,12 @@ filegroup(
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "constants",
|
||||||
|
srcs = ["constants.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/core/utils:file_util"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "hyperparameters",
|
name = "hyperparameters",
|
||||||
srcs = ["hyperparameters.py"],
|
srcs = ["hyperparameters.py"],
|
||||||
|
@ -37,6 +44,7 @@ py_library(
|
||||||
py_library(
|
py_library(
|
||||||
name = "model_options",
|
name = "model_options",
|
||||||
srcs = ["model_options.py"],
|
srcs = ["model_options.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/core/utils:loss_functions"],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
@ -72,11 +80,39 @@ py_library(
|
||||||
py_test(
|
py_test(
|
||||||
name = "dataset_test",
|
name = "dataset_test",
|
||||||
srcs = ["dataset_test.py"],
|
srcs = ["dataset_test.py"],
|
||||||
data = [
|
data = [":testdata"],
|
||||||
":testdata",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "face_stylizer",
|
||||||
|
srcs = ["face_stylizer.py"],
|
||||||
|
deps = [
|
||||||
|
":constants",
|
||||||
|
":face_stylizer_options",
|
||||||
|
":hyperparameters",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
|
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "face_stylizer_import",
|
||||||
|
srcs = ["__init__.py"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":face_stylizer",
|
||||||
|
":face_stylizer_options",
|
||||||
|
":hyperparameters",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -12,3 +12,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""MediaPipe Model Maker Python Public API For Face Stylization."""
|
"""MediaPipe Model Maker Python Public API For Face Stylization."""
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import dataset
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import model_options
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import model_spec
|
||||||
|
|
||||||
|
FaceStylizer = face_stylizer.FaceStylizer
|
||||||
|
SupportedModels = model_spec.SupportedModels
|
||||||
|
ModelOptions = model_options.FaceStylizerModelOptions
|
||||||
|
HParams = hyperparameters.HParams
|
||||||
|
Dataset = dataset.Dataset
|
||||||
|
FaceStylizerOptions = face_stylizer_options.FaceStylizerOptions
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Face stylizer model constants."""
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
|
|
||||||
|
# TODO: Move model files to GCS for downloading.
|
||||||
|
FACE_STYLIZER_ENCODER_MODEL_FILES = file_util.DownloadedFiles(
|
||||||
|
'face_stylizer/encoder',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_encoder.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
FACE_STYLIZER_DECODER_MODEL_FILES = file_util.DownloadedFiles(
|
||||||
|
'face_stylizer/decoder',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_decoder.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
FACE_STYLIZER_MAPPING_MODEL_FILES = file_util.DownloadedFiles(
|
||||||
|
'face_stylizer/mapping',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_mapping.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES = file_util.DownloadedFiles(
|
||||||
|
'face_stylizer/discriminator',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_discriminator.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
|
||||||
|
'face_stylizer/w_avg.npy',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dimension of the input style vector to the decoder
|
||||||
|
STYLE_DIM = 512
|
|
@ -0,0 +1,201 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""APIs to train face stylization model."""
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
|
||||||
|
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.vision.core import image_preprocessing
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import constants
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
|
||||||
|
|
||||||
|
|
||||||
|
class FaceStylizer(object):
|
||||||
|
"""FaceStylizer for building face stylization model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
w_avg: An average face latent code to regularize face generation in face
|
||||||
|
stylization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_spec: ms.ModelSpec,
|
||||||
|
model_options: model_opt.FaceStylizerModelOptions,
|
||||||
|
hparams: hp.HParams,
|
||||||
|
):
|
||||||
|
"""Initializes face stylizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_spec: Specification for the model.
|
||||||
|
model_options: Model options for creating face stylizer.
|
||||||
|
hparams: The hyperparameters for training face stylizer.
|
||||||
|
"""
|
||||||
|
self._model_spec = model_spec
|
||||||
|
self._model_options = model_options
|
||||||
|
self._hparams = hparams
|
||||||
|
# TODO: Support face alignment in image preprocessor.
|
||||||
|
self._preprocessor = image_preprocessing.Preprocessor(
|
||||||
|
input_shape=self._model_spec.input_image_shape,
|
||||||
|
num_classes=1,
|
||||||
|
mean_rgb=self._model_spec.mean_rgb,
|
||||||
|
stddev_rgb=self._model_spec.stddev_rgb,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
train_data: classification_ds.ClassificationDataset,
|
||||||
|
options: face_stylizer_options.FaceStylizerOptions,
|
||||||
|
) -> 'FaceStylizer':
|
||||||
|
"""Creates and trains a face stylizer with input datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: The input style image dataset for training the face stylizer.
|
||||||
|
options: The options to configure face stylizer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A FaceStylizer instant with the trained model.
|
||||||
|
"""
|
||||||
|
if options.model_options is None:
|
||||||
|
options.model_options = model_opt.FaceStylizerModelOptions()
|
||||||
|
|
||||||
|
if options.hparams is None:
|
||||||
|
options.hparams = hp.HParams()
|
||||||
|
|
||||||
|
spec = ms.SupportedModels.get(options.model)
|
||||||
|
|
||||||
|
face_stylizer = cls(
|
||||||
|
model_spec=spec,
|
||||||
|
model_options=options.model_options,
|
||||||
|
hparams=options.hparams,
|
||||||
|
)
|
||||||
|
face_stylizer._create_and_train_model(train_data)
|
||||||
|
return face_stylizer
|
||||||
|
|
||||||
|
def _create_and_train_model(
|
||||||
|
self, train_data: classification_ds.ClassificationDataset
|
||||||
|
):
|
||||||
|
"""Creates and trains the face stylizer model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
"""
|
||||||
|
self._create_model()
|
||||||
|
self._train_model(train_data=train_data, preprocessor=self._preprocessor)
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
"""Creates the componenets of face stylizer."""
|
||||||
|
self._encoder = model_util.load_keras_model(
|
||||||
|
constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
|
||||||
|
)
|
||||||
|
self._decoder = model_util.load_keras_model(
|
||||||
|
constants.FACE_STYLIZER_DECODER_MODEL_FILES.get_path()
|
||||||
|
)
|
||||||
|
self._mapping_network = model_util.load_keras_model(
|
||||||
|
constants.FACE_STYLIZER_MAPPING_MODEL_FILES.get_path()
|
||||||
|
)
|
||||||
|
self._discriminator = model_util.load_keras_model(
|
||||||
|
constants.FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES.get_path()
|
||||||
|
)
|
||||||
|
with tf.io.gfile.GFile(
|
||||||
|
constants.FACE_STYLIZER_W_FILES.get_path(), 'rb'
|
||||||
|
) as f:
|
||||||
|
w_avg = np.load(f)
|
||||||
|
|
||||||
|
self.w_avg = w_avg[: self._model_spec.style_block_num][np.newaxis]
|
||||||
|
|
||||||
|
def _train_model(
|
||||||
|
self,
|
||||||
|
train_data: classification_ds.ClassificationDataset,
|
||||||
|
preprocessor: Optional[Callable[..., bool]] = None,
|
||||||
|
):
|
||||||
|
"""Trains the face stylizer model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: The data for training model.
|
||||||
|
preprocessor: The image preprocessor.
|
||||||
|
"""
|
||||||
|
train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
|
||||||
|
|
||||||
|
# TODO: Support processing mulitple input style images. The
|
||||||
|
# input style images are expected to have similar style.
|
||||||
|
# style_sample represents a tuple of (style_image, style_label).
|
||||||
|
style_sample = next(iter(train_dataset))
|
||||||
|
style_img = style_sample[0]
|
||||||
|
|
||||||
|
batch_size = self._hparams.batch_size
|
||||||
|
label_in = tf.zeros(shape=[batch_size, 0])
|
||||||
|
|
||||||
|
style_encoding = self._encoder(style_img)
|
||||||
|
|
||||||
|
optimizer = tf.keras.optimizers.Adam(
|
||||||
|
learning_rate=self._hparams.learning_rate,
|
||||||
|
beta_1=self._hparams.beta_1,
|
||||||
|
beta_2=self._hparams.beta_2,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_perceptual_quality_loss = loss_functions.ImagePerceptualQualityLoss(
|
||||||
|
loss_weight=self._model_options.perception_loss_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(self._hparams.epochs):
|
||||||
|
noise = tf.random.normal(shape=[batch_size, constants.STYLE_DIM])
|
||||||
|
|
||||||
|
mean_w = self._mapping_network([noise, label_in], training=False)[
|
||||||
|
:, : self._model_spec.style_block_num
|
||||||
|
]
|
||||||
|
style_encodings = tf.tile(style_encoding, [batch_size, 1, 1])
|
||||||
|
|
||||||
|
in_latent = tf.Variable(tf.identity(style_encodings))
|
||||||
|
|
||||||
|
alpha = self._model_options.alpha
|
||||||
|
for swap_layer in self._model_options.swap_layers:
|
||||||
|
in_latent = in_latent[:, swap_layer].assign(
|
||||||
|
alpha * style_encodings[:, swap_layer]
|
||||||
|
+ (1 - alpha) * mean_w[:, swap_layer]
|
||||||
|
)
|
||||||
|
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
outputs = self._decoder(
|
||||||
|
{'inputs': in_latent + self.w_avg},
|
||||||
|
training=False,
|
||||||
|
)
|
||||||
|
gen_img = outputs['image'][-1]
|
||||||
|
|
||||||
|
real_feature = self._discriminator(
|
||||||
|
[tf.transpose(style_img, [0, 3, 1, 2]), label_in]
|
||||||
|
)
|
||||||
|
gen_feature = self._discriminator(
|
||||||
|
[tf.transpose(gen_img, [0, 3, 1, 2]), label_in]
|
||||||
|
)
|
||||||
|
|
||||||
|
style_loss = image_perceptual_quality_loss(gen_img, style_img)
|
||||||
|
style_loss += (
|
||||||
|
tf.keras.losses.MeanAbsoluteError()(real_feature, gen_feature)
|
||||||
|
* self._model_options.adv_loss_weight
|
||||||
|
)
|
||||||
|
tf.compat.v1.logging.info(f'Iteration {i} loss: {style_loss.numpy()}')
|
||||||
|
|
||||||
|
tvars = self._decoder.trainable_variables
|
||||||
|
grads = tape.gradient(style_loss, tvars)
|
||||||
|
optimizer.apply_gradients(list(zip(grads, tvars)))
|
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision import face_stylizer
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class FaceStylizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def _load_data(self):
|
||||||
|
"""Loads training dataset."""
|
||||||
|
input_data_dir = test_utils.get_test_data_path('testdata')
|
||||||
|
|
||||||
|
data = face_stylizer.Dataset.from_folder(dirname=input_data_dir)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
|
||||||
|
"""Evaluates the fine-tuned face stylizer model."""
|
||||||
|
test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32)
|
||||||
|
test_image_batch = test_image[tf.newaxis]
|
||||||
|
in_latent = model._encoder(test_image_batch)
|
||||||
|
output = model._decoder({'inputs': in_latent + model.w_avg})
|
||||||
|
self.assertEqual(output['image'][-1].shape, (1, 256, 256, 3))
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self._train_data = self._load_data()
|
||||||
|
|
||||||
|
def test_finetuning_face_stylizer_with_single_input_style_image(self):
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
face_stylizer_options = face_stylizer.FaceStylizerOptions(
|
||||||
|
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
|
||||||
|
hparams=face_stylizer.HParams(epochs=1),
|
||||||
|
)
|
||||||
|
model = face_stylizer.FaceStylizer.create(
|
||||||
|
train_data=self._train_data, options=face_stylizer_options
|
||||||
|
)
|
||||||
|
self._evaluate_saved_model(model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -13,8 +13,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Configurable model options for face stylizer models."""
|
"""Configurable model options for face stylizer models."""
|
||||||
|
|
||||||
|
from typing import Sequence
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import List
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||||
|
|
||||||
|
|
||||||
|
def _default_perceptual_quality_loss_weight():
|
||||||
|
"""Default perceptual quality loss weight for face stylizer."""
|
||||||
|
return loss_functions.PerceptualLossWeight(l1=2.0, content=20.0, style=10.0)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Add more detailed instructions about hyperparameter tuning.
|
# TODO: Add more detailed instructions about hyperparameter tuning.
|
||||||
|
@ -26,12 +33,17 @@ class FaceStylizerModelOptions:
|
||||||
swap_layers: The layers of feature to be interpolated between encoding
|
swap_layers: The layers of feature to be interpolated between encoding
|
||||||
features and StyleGAN input features.
|
features and StyleGAN input features.
|
||||||
alpha: Weighting coefficient for swapping layer interpolation.
|
alpha: Weighting coefficient for swapping layer interpolation.
|
||||||
adv_loss_weight: Weighting coeffcieint of adversarial loss versus perceptual
|
perception_loss_weight: Weighting coefficients of image perception quality
|
||||||
loss.
|
loss.
|
||||||
|
adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
|
||||||
|
perceptual quality loss.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
swap_layers: List[int] = dataclasses.field(
|
swap_layers: Sequence[int] = dataclasses.field(
|
||||||
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
|
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
|
||||||
)
|
)
|
||||||
alpha: float = 1.0
|
alpha: float = 1.0
|
||||||
|
perception_loss_weight: loss_functions.PerceptualLossWeight = (
|
||||||
|
dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
|
||||||
|
)
|
||||||
adv_loss_weight: float = 1.0
|
adv_loss_weight: float = 1.0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user