From 1eeb89e95fa10777081de1f291882cf90981776c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 30 Mar 2023 16:04:47 -0700 Subject: [PATCH] Add the model configuration and training hyperparameters for BlazeFaceStylizer. PiperOrigin-RevId: 520767282 --- .../python/vision/face_stylizer/BUILD | 34 ++++++++++ .../face_stylizer/face_stylizer_options.py | 36 +++++++++++ .../vision/face_stylizer/hyperparameters.py | 39 ++++++++++++ .../vision/face_stylizer/model_options.py | 37 +++++++++++ .../python/vision/face_stylizer/model_spec.py | 63 +++++++++++++++++++ .../vision/face_stylizer/model_spec_test.py | 44 +++++++++++++ 6 files changed, 253 insertions(+) create mode 100644 mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_options.py create mode 100644 mediapipe/model_maker/python/vision/face_stylizer/hyperparameters.py create mode 100644 mediapipe/model_maker/python/vision/face_stylizer/model_options.py create mode 100644 mediapipe/model_maker/python/vision/face_stylizer/model_spec.py create mode 100644 mediapipe/model_maker/python/vision/face_stylizer/model_spec_test.py diff --git a/mediapipe/model_maker/python/vision/face_stylizer/BUILD b/mediapipe/model_maker/python/vision/face_stylizer/BUILD index 804511540..b5e0399d1 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/BUILD +++ b/mediapipe/model_maker/python/vision/face_stylizer/BUILD @@ -26,6 +26,40 @@ filegroup( ]), ) +py_library( + name = "hyperparameters", + srcs = ["hyperparameters.py"], + deps = [ + "//mediapipe/model_maker/python/core:hyperparameters", + ], +) + +py_library( + name = "model_options", + srcs = ["model_options.py"], +) + +py_library( + name = "model_spec", + srcs = ["model_spec.py"], +) + +py_test( + name = "model_spec_test", + srcs = ["model_spec_test.py"], + deps = [":model_spec"], +) + +py_library( + name = "face_stylizer_options", + srcs = ["face_stylizer_options.py"], + deps = [ + ":hyperparameters", + ":model_options", + ":model_spec", + ], +) + py_library( name = "dataset", srcs = ["dataset.py"], diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_options.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_options.py new file mode 100644 index 000000000..e0e2441d1 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_options.py @@ -0,0 +1,36 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Options for building face stylizer.""" + +import dataclasses +from typing import Optional + +from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters +from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt +from mediapipe.model_maker.python.vision.face_stylizer import model_spec + + +@dataclasses.dataclass +class FaceStylizerOptions: + """Configurable options for building face stylizer. + + Attributes: + model: A model from the SupportedModels enum. + hparams: A set of hyperparameters used to train the face stylizer. + model_options: A set of options for configuring the model. + """ + + model: model_spec.SupportedModels + model_options: Optional[model_opt.FaceStylizerModelOptions] = None + hparams: Optional[hyperparameters.HParams] = None diff --git a/mediapipe/model_maker/python/vision/face_stylizer/hyperparameters.py b/mediapipe/model_maker/python/vision/face_stylizer/hyperparameters.py new file mode 100644 index 000000000..0a129a721 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/hyperparameters.py @@ -0,0 +1,39 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hyperparameters for training on-device face stylization models.""" + +import dataclasses + +from mediapipe.model_maker.python.core import hyperparameters as hp + + +@dataclasses.dataclass +class HParams(hp.BaseHParams): + """The hyperparameters for training face stylizers. + + Attributes: + learning_rate: Learning rate to use for gradient descent training. + batch_size: Batch size for training. + epochs: Number of training iterations. + beta_1: beta_1 used in tf.keras.optimizers.Adam. + beta_2: beta_2 used in tf.keras.optimizers.Adam. + """ + + # Parameters from BaseHParams class. + learning_rate: float = 5e-5 + batch_size: int = 4 + epochs: int = 100 + # Parameters for face stylizer. + beta_1 = 0.0 + beta_2 = 0.99 diff --git a/mediapipe/model_maker/python/vision/face_stylizer/model_options.py b/mediapipe/model_maker/python/vision/face_stylizer/model_options.py new file mode 100644 index 000000000..064e2d027 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/model_options.py @@ -0,0 +1,37 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configurable model options for face stylizer models.""" + +import dataclasses +from typing import List + + +# TODO: Add more detailed instructions about hyperparameter tuning. +@dataclasses.dataclass +class FaceStylizerModelOptions: + """Configurable model options for face stylizer models. + + Attributes: + swap_layers: The layers of feature to be interpolated between encoding + features and StyleGAN input features. + alpha: Weighting coefficient for swapping layer interpolation. + adv_loss_weight: Weighting coeffcieint of adversarial loss versus perceptual + loss. + """ + + swap_layers: List[int] = dataclasses.field( + default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11] + ) + alpha: float = 1.0 + adv_loss_weight: float = 1.0 diff --git a/mediapipe/model_maker/python/vision/face_stylizer/model_spec.py b/mediapipe/model_maker/python/vision/face_stylizer/model_spec.py new file mode 100644 index 000000000..6f5126f0b --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/model_spec.py @@ -0,0 +1,63 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Face stylizer model specification.""" + +import enum +import functools +from typing import List + + +class ModelSpec(object): + """Specification of face stylizer model.""" + + mean_rgb = [127.5] + stddev_rgb = [127.5] + + def __init__( + self, style_block_num: int, input_image_shape: List[int], name: str = '' + ): + """Initializes a new instance of the `ModelSpec` class for face stylizer. + + Args: + style_block_num: int, number of style block in the decoder. + input_image_shape: list of int, input image shape. + name: str, model spec name. + """ + self.style_block_num = style_block_num + self.input_image_shape = input_image_shape + self.name = name + + +blaze_face_stylizer_256_spec = functools.partial( + ModelSpec, + style_block_num=12, + input_image_shape=[256, 256], + name='blaze_face_stylizer_256', +) + + +# TODO: Document the exposed models. +@enum.unique +class SupportedModels(enum.Enum): + """Face stylizer model supported by MediaPipe model maker.""" + + BLAZE_FACE_STYLIZER_256 = blaze_face_stylizer_256_spec + + @classmethod + def get(cls, spec: 'SupportedModels') -> 'ModelSpec': + """Gets model spec from the input enum and initializes it.""" + if spec not in cls: + raise TypeError('Unsupported face stylizer spec: {}'.format(spec)) + + return spec.value() diff --git a/mediapipe/model_maker/python/vision/face_stylizer/model_spec_test.py b/mediapipe/model_maker/python/vision/face_stylizer/model_spec_test.py new file mode 100644 index 000000000..8be3242ac --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/model_spec_test.py @@ -0,0 +1,44 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tensorflow as tf + +from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms + + +class ModelSpecTest(tf.test.TestCase): + + def test_predefine_spec(self): + blaze_face_stylizer_256_spec = ms.blaze_face_stylizer_256_spec() + self.assertIsInstance(blaze_face_stylizer_256_spec, ms.ModelSpec) + self.assertEqual(blaze_face_stylizer_256_spec.style_block_num, 12) + self.assertAllEqual( + blaze_face_stylizer_256_spec.input_image_shape, [256, 256] + ) + self.assertEqual( + blaze_face_stylizer_256_spec.name, 'blaze_face_stylizer_256' + ) + + def test_predefine_spec_enum(self): + blaze_face_stylizer_256 = ms.SupportedModels.BLAZE_FACE_STYLIZER_256 + spec = ms.SupportedModels.get(blaze_face_stylizer_256) + self.assertIsInstance(spec, ms.ModelSpec) + self.assertEqual(spec.style_block_num, 12) + self.assertAllEqual(spec.input_image_shape, [256, 256]) + self.assertEqual(spec.name, 'blaze_face_stylizer_256') + + +if __name__ == '__main__': + tf.test.main()