Add the model configuration and training hyperparameters for BlazeFaceStylizer.

PiperOrigin-RevId: 520767282
This commit is contained in:
MediaPipe Team 2023-03-30 16:04:47 -07:00 committed by Copybara-Service
parent a4923ca7aa
commit 1eeb89e95f
6 changed files with 253 additions and 0 deletions

View File

@ -26,6 +26,40 @@ filegroup(
]),
)
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
deps = [
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "model_options",
srcs = ["model_options.py"],
)
py_library(
name = "model_spec",
srcs = ["model_spec.py"],
)
py_test(
name = "model_spec_test",
srcs = ["model_spec_test.py"],
deps = [":model_spec"],
)
py_library(
name = "face_stylizer_options",
srcs = ["face_stylizer_options.py"],
deps = [
":hyperparameters",
":model_options",
":model_spec",
],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],

View File

@ -0,0 +1,36 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options for building face stylizer."""
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
from mediapipe.model_maker.python.vision.face_stylizer import model_spec
@dataclasses.dataclass
class FaceStylizerOptions:
"""Configurable options for building face stylizer.
Attributes:
model: A model from the SupportedModels enum.
hparams: A set of hyperparameters used to train the face stylizer.
model_options: A set of options for configuring the model.
"""
model: model_spec.SupportedModels
model_options: Optional[model_opt.FaceStylizerModelOptions] = None
hparams: Optional[hyperparameters.HParams] = None

View File

@ -0,0 +1,39 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training on-device face stylization models."""
import dataclasses
from mediapipe.model_maker.python.core import hyperparameters as hp
@dataclasses.dataclass
class HParams(hp.BaseHParams):
"""The hyperparameters for training face stylizers.
Attributes:
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations.
beta_1: beta_1 used in tf.keras.optimizers.Adam.
beta_2: beta_2 used in tf.keras.optimizers.Adam.
"""
# Parameters from BaseHParams class.
learning_rate: float = 5e-5
batch_size: int = 4
epochs: int = 100
# Parameters for face stylizer.
beta_1 = 0.0
beta_2 = 0.99

View File

@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for face stylizer models."""
import dataclasses
from typing import List
# TODO: Add more detailed instructions about hyperparameter tuning.
@dataclasses.dataclass
class FaceStylizerModelOptions:
"""Configurable model options for face stylizer models.
Attributes:
swap_layers: The layers of feature to be interpolated between encoding
features and StyleGAN input features.
alpha: Weighting coefficient for swapping layer interpolation.
adv_loss_weight: Weighting coeffcieint of adversarial loss versus perceptual
loss.
"""
swap_layers: List[int] = dataclasses.field(
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
)
alpha: float = 1.0
adv_loss_weight: float = 1.0

View File

@ -0,0 +1,63 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Face stylizer model specification."""
import enum
import functools
from typing import List
class ModelSpec(object):
"""Specification of face stylizer model."""
mean_rgb = [127.5]
stddev_rgb = [127.5]
def __init__(
self, style_block_num: int, input_image_shape: List[int], name: str = ''
):
"""Initializes a new instance of the `ModelSpec` class for face stylizer.
Args:
style_block_num: int, number of style block in the decoder.
input_image_shape: list of int, input image shape.
name: str, model spec name.
"""
self.style_block_num = style_block_num
self.input_image_shape = input_image_shape
self.name = name
blaze_face_stylizer_256_spec = functools.partial(
ModelSpec,
style_block_num=12,
input_image_shape=[256, 256],
name='blaze_face_stylizer_256',
)
# TODO: Document the exposed models.
@enum.unique
class SupportedModels(enum.Enum):
"""Face stylizer model supported by MediaPipe model maker."""
BLAZE_FACE_STYLIZER_256 = blaze_face_stylizer_256_spec
@classmethod
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
"""Gets model spec from the input enum and initializes it."""
if spec not in cls:
raise TypeError('Unsupported face stylizer spec: {}'.format(spec))
return spec.value()

View File

@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
class ModelSpecTest(tf.test.TestCase):
def test_predefine_spec(self):
blaze_face_stylizer_256_spec = ms.blaze_face_stylizer_256_spec()
self.assertIsInstance(blaze_face_stylizer_256_spec, ms.ModelSpec)
self.assertEqual(blaze_face_stylizer_256_spec.style_block_num, 12)
self.assertAllEqual(
blaze_face_stylizer_256_spec.input_image_shape, [256, 256]
)
self.assertEqual(
blaze_face_stylizer_256_spec.name, 'blaze_face_stylizer_256'
)
def test_predefine_spec_enum(self):
blaze_face_stylizer_256 = ms.SupportedModels.BLAZE_FACE_STYLIZER_256
spec = ms.SupportedModels.get(blaze_face_stylizer_256)
self.assertIsInstance(spec, ms.ModelSpec)
self.assertEqual(spec.style_block_num, 12)
self.assertAllEqual(spec.input_image_shape, [256, 256])
self.assertEqual(spec.name, 'blaze_face_stylizer_256')
if __name__ == '__main__':
tf.test.main()