Add the model configuration and training hyperparameters for BlazeFaceStylizer.
PiperOrigin-RevId: 520767282
This commit is contained in:
parent
a4923ca7aa
commit
1eeb89e95f
|
@ -26,6 +26,40 @@ filegroup(
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "hyperparameters",
|
||||||
|
srcs = ["hyperparameters.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_options",
|
||||||
|
srcs = ["model_options.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_spec",
|
||||||
|
srcs = ["model_spec.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "model_spec_test",
|
||||||
|
srcs = ["model_spec_test.py"],
|
||||||
|
deps = [":model_spec"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "face_stylizer_options",
|
||||||
|
srcs = ["face_stylizer_options.py"],
|
||||||
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "dataset",
|
name = "dataset",
|
||||||
srcs = ["dataset.py"],
|
srcs = ["dataset.py"],
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Options for building face stylizer."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import model_spec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class FaceStylizerOptions:
|
||||||
|
"""Configurable options for building face stylizer.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: A model from the SupportedModels enum.
|
||||||
|
hparams: A set of hyperparameters used to train the face stylizer.
|
||||||
|
model_options: A set of options for configuring the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: model_spec.SupportedModels
|
||||||
|
model_options: Optional[model_opt.FaceStylizerModelOptions] = None
|
||||||
|
hparams: Optional[hyperparameters.HParams] = None
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Hyperparameters for training on-device face stylization models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class HParams(hp.BaseHParams):
|
||||||
|
"""The hyperparameters for training face stylizers.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate: Learning rate to use for gradient descent training.
|
||||||
|
batch_size: Batch size for training.
|
||||||
|
epochs: Number of training iterations.
|
||||||
|
beta_1: beta_1 used in tf.keras.optimizers.Adam.
|
||||||
|
beta_2: beta_2 used in tf.keras.optimizers.Adam.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Parameters from BaseHParams class.
|
||||||
|
learning_rate: float = 5e-5
|
||||||
|
batch_size: int = 4
|
||||||
|
epochs: int = 100
|
||||||
|
# Parameters for face stylizer.
|
||||||
|
beta_1 = 0.0
|
||||||
|
beta_2 = 0.99
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configurable model options for face stylizer models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Add more detailed instructions about hyperparameter tuning.
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class FaceStylizerModelOptions:
|
||||||
|
"""Configurable model options for face stylizer models.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
swap_layers: The layers of feature to be interpolated between encoding
|
||||||
|
features and StyleGAN input features.
|
||||||
|
alpha: Weighting coefficient for swapping layer interpolation.
|
||||||
|
adv_loss_weight: Weighting coeffcieint of adversarial loss versus perceptual
|
||||||
|
loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
swap_layers: List[int] = dataclasses.field(
|
||||||
|
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
|
||||||
|
)
|
||||||
|
alpha: float = 1.0
|
||||||
|
adv_loss_weight: float = 1.0
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Face stylizer model specification."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import functools
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSpec(object):
|
||||||
|
"""Specification of face stylizer model."""
|
||||||
|
|
||||||
|
mean_rgb = [127.5]
|
||||||
|
stddev_rgb = [127.5]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, style_block_num: int, input_image_shape: List[int], name: str = ''
|
||||||
|
):
|
||||||
|
"""Initializes a new instance of the `ModelSpec` class for face stylizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
style_block_num: int, number of style block in the decoder.
|
||||||
|
input_image_shape: list of int, input image shape.
|
||||||
|
name: str, model spec name.
|
||||||
|
"""
|
||||||
|
self.style_block_num = style_block_num
|
||||||
|
self.input_image_shape = input_image_shape
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
|
blaze_face_stylizer_256_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
style_block_num=12,
|
||||||
|
input_image_shape=[256, 256],
|
||||||
|
name='blaze_face_stylizer_256',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Document the exposed models.
|
||||||
|
@enum.unique
|
||||||
|
class SupportedModels(enum.Enum):
|
||||||
|
"""Face stylizer model supported by MediaPipe model maker."""
|
||||||
|
|
||||||
|
BLAZE_FACE_STYLIZER_256 = blaze_face_stylizer_256_spec
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
||||||
|
"""Gets model spec from the input enum and initializes it."""
|
||||||
|
if spec not in cls:
|
||||||
|
raise TypeError('Unsupported face stylizer spec: {}'.format(spec))
|
||||||
|
|
||||||
|
return spec.value()
|
|
@ -0,0 +1,44 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSpecTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_predefine_spec(self):
|
||||||
|
blaze_face_stylizer_256_spec = ms.blaze_face_stylizer_256_spec()
|
||||||
|
self.assertIsInstance(blaze_face_stylizer_256_spec, ms.ModelSpec)
|
||||||
|
self.assertEqual(blaze_face_stylizer_256_spec.style_block_num, 12)
|
||||||
|
self.assertAllEqual(
|
||||||
|
blaze_face_stylizer_256_spec.input_image_shape, [256, 256]
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
blaze_face_stylizer_256_spec.name, 'blaze_face_stylizer_256'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_predefine_spec_enum(self):
|
||||||
|
blaze_face_stylizer_256 = ms.SupportedModels.BLAZE_FACE_STYLIZER_256
|
||||||
|
spec = ms.SupportedModels.get(blaze_face_stylizer_256)
|
||||||
|
self.assertIsInstance(spec, ms.ModelSpec)
|
||||||
|
self.assertEqual(spec.style_block_num, 12)
|
||||||
|
self.assertAllEqual(spec.input_image_shape, [256, 256])
|
||||||
|
self.assertEqual(spec.name, 'blaze_face_stylizer_256')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
Loading…
Reference in New Issue
Block a user