From 56df724c36b53445862ac1e726220129eee981c0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 25 Apr 2023 00:45:26 -0700 Subject: [PATCH] Add customizable face stylizer module in MediaPipe model maker PiperOrigin-RevId: 526883862 --- .../python/vision/face_stylizer/BUILD | 42 +++- .../python/vision/face_stylizer/__init__.py | 14 ++ .../python/vision/face_stylizer/constants.py | 45 ++++ .../vision/face_stylizer/face_stylizer.py | 201 ++++++++++++++++++ .../face_stylizer/face_stylizer_test.py | 55 +++++ .../vision/face_stylizer/model_options.py | 18 +- 6 files changed, 369 insertions(+), 6 deletions(-) create mode 100644 mediapipe/model_maker/python/vision/face_stylizer/constants.py create mode 100644 mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py create mode 100644 mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py diff --git a/mediapipe/model_maker/python/vision/face_stylizer/BUILD b/mediapipe/model_maker/python/vision/face_stylizer/BUILD index b5e0399d1..53fb684e4 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/BUILD +++ b/mediapipe/model_maker/python/vision/face_stylizer/BUILD @@ -14,6 +14,7 @@ # Placeholder for internal Python strict test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python GPU test rule. licenses(["notice"]) @@ -26,6 +27,12 @@ filegroup( ]), ) +py_library( + name = "constants", + srcs = ["constants.py"], + deps = ["//mediapipe/model_maker/python/core/utils:file_util"], +) + py_library( name = "hyperparameters", srcs = ["hyperparameters.py"], @@ -37,6 +44,7 @@ py_library( py_library( name = "model_options", srcs = ["model_options.py"], + deps = ["//mediapipe/model_maker/python/core/utils:loss_functions"], ) py_library( @@ -72,11 +80,39 @@ py_library( py_test( name = "dataset_test", srcs = ["dataset_test.py"], - data = [ - ":testdata", - ], + data = [":testdata"], deps = [ ":dataset", "//mediapipe/tasks/python/test:test_utils", ], ) + +py_library( + name = "face_stylizer", + srcs = ["face_stylizer.py"], + deps = [ + ":constants", + ":face_stylizer_options", + ":hyperparameters", + ":model_options", + ":model_spec", + "//mediapipe/model_maker/python/core/data:classification_dataset", + "//mediapipe/model_maker/python/core/utils:loss_functions", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/model_maker/python/vision/core:image_preprocessing", + ], +) + +py_library( + name = "face_stylizer_import", + srcs = ["__init__.py"], + visibility = ["//visibility:public"], + deps = [ + ":dataset", + ":face_stylizer", + ":face_stylizer_options", + ":hyperparameters", + ":model_options", + ":model_spec", + ], +) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/__init__.py b/mediapipe/model_maker/python/vision/face_stylizer/__init__.py index e935c0c76..3cec27964 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/__init__.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/__init__.py @@ -12,3 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. """MediaPipe Model Maker Python Public API For Face Stylization.""" + +from mediapipe.model_maker.python.vision.face_stylizer import dataset +from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer +from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options +from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters +from mediapipe.model_maker.python.vision.face_stylizer import model_options +from mediapipe.model_maker.python.vision.face_stylizer import model_spec + +FaceStylizer = face_stylizer.FaceStylizer +SupportedModels = model_spec.SupportedModels +ModelOptions = model_options.FaceStylizerModelOptions +HParams = hyperparameters.HParams +Dataset = dataset.Dataset +FaceStylizerOptions = face_stylizer_options.FaceStylizerOptions diff --git a/mediapipe/model_maker/python/vision/face_stylizer/constants.py b/mediapipe/model_maker/python/vision/face_stylizer/constants.py new file mode 100644 index 000000000..e7a03aebd --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/constants.py @@ -0,0 +1,45 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Face stylizer model constants.""" + +from mediapipe.model_maker.python.core.utils import file_util + +# TODO: Move model files to GCS for downloading. +FACE_STYLIZER_ENCODER_MODEL_FILES = file_util.DownloadedFiles( + 'face_stylizer/encoder', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_encoder.tar.gz', + is_folder=True, +) +FACE_STYLIZER_DECODER_MODEL_FILES = file_util.DownloadedFiles( + 'face_stylizer/decoder', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_decoder.tar.gz', + is_folder=True, +) +FACE_STYLIZER_MAPPING_MODEL_FILES = file_util.DownloadedFiles( + 'face_stylizer/mapping', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_mapping.tar.gz', + is_folder=True, +) +FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES = file_util.DownloadedFiles( + 'face_stylizer/discriminator', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_discriminator.tar.gz', + is_folder=True, +) +FACE_STYLIZER_W_FILES = file_util.DownloadedFiles( + 'face_stylizer/w_avg.npy', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy', +) + +# Dimension of the input style vector to the decoder +STYLE_DIM = 512 diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py new file mode 100644 index 000000000..3e850582f --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py @@ -0,0 +1,201 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""APIs to train face stylization model.""" + +from typing import Callable, Optional + +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds +from mediapipe.model_maker.python.core.utils import loss_functions +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.vision.core import image_preprocessing +from mediapipe.model_maker.python.vision.face_stylizer import constants +from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options +from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp +from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt +from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms + + +class FaceStylizer(object): + """FaceStylizer for building face stylization model. + + Attributes: + w_avg: An average face latent code to regularize face generation in face + stylization. + """ + + def __init__( + self, + model_spec: ms.ModelSpec, + model_options: model_opt.FaceStylizerModelOptions, + hparams: hp.HParams, + ): + """Initializes face stylizer. + + Args: + model_spec: Specification for the model. + model_options: Model options for creating face stylizer. + hparams: The hyperparameters for training face stylizer. + """ + self._model_spec = model_spec + self._model_options = model_options + self._hparams = hparams + # TODO: Support face alignment in image preprocessor. + self._preprocessor = image_preprocessing.Preprocessor( + input_shape=self._model_spec.input_image_shape, + num_classes=1, + mean_rgb=self._model_spec.mean_rgb, + stddev_rgb=self._model_spec.stddev_rgb, + ) + + @classmethod + def create( + cls, + train_data: classification_ds.ClassificationDataset, + options: face_stylizer_options.FaceStylizerOptions, + ) -> 'FaceStylizer': + """Creates and trains a face stylizer with input datasets. + + Args: + train_data: The input style image dataset for training the face stylizer. + options: The options to configure face stylizer. + + Returns: + A FaceStylizer instant with the trained model. + """ + if options.model_options is None: + options.model_options = model_opt.FaceStylizerModelOptions() + + if options.hparams is None: + options.hparams = hp.HParams() + + spec = ms.SupportedModels.get(options.model) + + face_stylizer = cls( + model_spec=spec, + model_options=options.model_options, + hparams=options.hparams, + ) + face_stylizer._create_and_train_model(train_data) + return face_stylizer + + def _create_and_train_model( + self, train_data: classification_ds.ClassificationDataset + ): + """Creates and trains the face stylizer model. + + Args: + train_data: Training data. + """ + self._create_model() + self._train_model(train_data=train_data, preprocessor=self._preprocessor) + + def _create_model(self): + """Creates the componenets of face stylizer.""" + self._encoder = model_util.load_keras_model( + constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path() + ) + self._decoder = model_util.load_keras_model( + constants.FACE_STYLIZER_DECODER_MODEL_FILES.get_path() + ) + self._mapping_network = model_util.load_keras_model( + constants.FACE_STYLIZER_MAPPING_MODEL_FILES.get_path() + ) + self._discriminator = model_util.load_keras_model( + constants.FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES.get_path() + ) + with tf.io.gfile.GFile( + constants.FACE_STYLIZER_W_FILES.get_path(), 'rb' + ) as f: + w_avg = np.load(f) + + self.w_avg = w_avg[: self._model_spec.style_block_num][np.newaxis] + + def _train_model( + self, + train_data: classification_ds.ClassificationDataset, + preprocessor: Optional[Callable[..., bool]] = None, + ): + """Trains the face stylizer model. + + Args: + train_data: The data for training model. + preprocessor: The image preprocessor. + """ + train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor) + + # TODO: Support processing mulitple input style images. The + # input style images are expected to have similar style. + # style_sample represents a tuple of (style_image, style_label). + style_sample = next(iter(train_dataset)) + style_img = style_sample[0] + + batch_size = self._hparams.batch_size + label_in = tf.zeros(shape=[batch_size, 0]) + + style_encoding = self._encoder(style_img) + + optimizer = tf.keras.optimizers.Adam( + learning_rate=self._hparams.learning_rate, + beta_1=self._hparams.beta_1, + beta_2=self._hparams.beta_2, + ) + + image_perceptual_quality_loss = loss_functions.ImagePerceptualQualityLoss( + loss_weight=self._model_options.perception_loss_weight + ) + + for i in range(self._hparams.epochs): + noise = tf.random.normal(shape=[batch_size, constants.STYLE_DIM]) + + mean_w = self._mapping_network([noise, label_in], training=False)[ + :, : self._model_spec.style_block_num + ] + style_encodings = tf.tile(style_encoding, [batch_size, 1, 1]) + + in_latent = tf.Variable(tf.identity(style_encodings)) + + alpha = self._model_options.alpha + for swap_layer in self._model_options.swap_layers: + in_latent = in_latent[:, swap_layer].assign( + alpha * style_encodings[:, swap_layer] + + (1 - alpha) * mean_w[:, swap_layer] + ) + + with tf.GradientTape() as tape: + outputs = self._decoder( + {'inputs': in_latent + self.w_avg}, + training=False, + ) + gen_img = outputs['image'][-1] + + real_feature = self._discriminator( + [tf.transpose(style_img, [0, 3, 1, 2]), label_in] + ) + gen_feature = self._discriminator( + [tf.transpose(gen_img, [0, 3, 1, 2]), label_in] + ) + + style_loss = image_perceptual_quality_loss(gen_img, style_img) + style_loss += ( + tf.keras.losses.MeanAbsoluteError()(real_feature, gen_feature) + * self._model_options.adv_loss_weight + ) + tf.compat.v1.logging.info(f'Iteration {i} loss: {style_loss.numpy()}') + + tvars = self._decoder.trainable_variables + grads = tape.gradient(style_loss, tvars) + optimizer.apply_gradients(list(zip(grads, tvars))) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py new file mode 100644 index 000000000..c244a3ee4 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py @@ -0,0 +1,55 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + +from mediapipe.model_maker.python.vision import face_stylizer +from mediapipe.tasks.python.test import test_utils + + +class FaceStylizerTest(tf.test.TestCase): + + def _load_data(self): + """Loads training dataset.""" + input_data_dir = test_utils.get_test_data_path('testdata') + + data = face_stylizer.Dataset.from_folder(dirname=input_data_dir) + return data + + def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer): + """Evaluates the fine-tuned face stylizer model.""" + test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32) + test_image_batch = test_image[tf.newaxis] + in_latent = model._encoder(test_image_batch) + output = model._decoder({'inputs': in_latent + model.w_avg}) + self.assertEqual(output['image'][-1].shape, (1, 256, 256, 3)) + + def setUp(self): + super().setUp() + self._train_data = self._load_data() + + def test_finetuning_face_stylizer_with_single_input_style_image(self): + with self.test_session(use_gpu=True): + face_stylizer_options = face_stylizer.FaceStylizerOptions( + model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256, + hparams=face_stylizer.HParams(epochs=1), + ) + model = face_stylizer.FaceStylizer.create( + train_data=self._train_data, options=face_stylizer_options + ) + self._evaluate_saved_model(model) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/face_stylizer/model_options.py b/mediapipe/model_maker/python/vision/face_stylizer/model_options.py index 064e2d027..ff5a54eba 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/model_options.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/model_options.py @@ -13,8 +13,15 @@ # limitations under the License. """Configurable model options for face stylizer models.""" +from typing import Sequence import dataclasses -from typing import List + +from mediapipe.model_maker.python.core.utils import loss_functions + + +def _default_perceptual_quality_loss_weight(): + """Default perceptual quality loss weight for face stylizer.""" + return loss_functions.PerceptualLossWeight(l1=2.0, content=20.0, style=10.0) # TODO: Add more detailed instructions about hyperparameter tuning. @@ -26,12 +33,17 @@ class FaceStylizerModelOptions: swap_layers: The layers of feature to be interpolated between encoding features and StyleGAN input features. alpha: Weighting coefficient for swapping layer interpolation. - adv_loss_weight: Weighting coeffcieint of adversarial loss versus perceptual + perception_loss_weight: Weighting coefficients of image perception quality loss. + adv_loss_weight: Weighting coeffcieint of adversarial loss versus image + perceptual quality loss. """ - swap_layers: List[int] = dataclasses.field( + swap_layers: Sequence[int] = dataclasses.field( default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11] ) alpha: float = 1.0 + perception_loss_weight: loss_functions.PerceptualLossWeight = ( + dataclasses.field(default_factory=_default_perceptual_quality_loss_weight) + ) adv_loss_weight: float = 1.0