Add the model configuration and training hyperparameters for BlazeFaceStylizer.
PiperOrigin-RevId: 520767282
This commit is contained in:
parent
a4923ca7aa
commit
1eeb89e95f
|
@ -26,6 +26,40 @@ filegroup(
|
|||
]),
|
||||
)
|
||||
|
||||
# Training hyperparameters for the face stylizer; extends the shared
# model_maker core hyperparameters library.
py_library(
    name = "hyperparameters",
    srcs = ["hyperparameters.py"],
    deps = [
        "//mediapipe/model_maker/python/core:hyperparameters",
    ],
)

# Configurable model options for face stylizer models.
py_library(
    name = "model_options",
    srcs = ["model_options.py"],
)

# Face stylizer model specification and the SupportedModels enum.
py_library(
    name = "model_spec",
    srcs = ["model_spec.py"],
)

# Unit test for the predefined model spec and its enum accessor.
py_test(
    name = "model_spec_test",
    srcs = ["model_spec_test.py"],
    deps = [":model_spec"],
)

# Options dataclass that bundles the model choice, the model options and the
# training hyperparameters.
py_library(
    name = "face_stylizer_options",
    srcs = ["face_stylizer_options.py"],
    deps = [
        ":hyperparameters",
        ":model_options",
        ":model_spec",
    ],
)
|
||||
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Options for building face stylizer."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import model_spec
|
||||
|
||||
|
||||
@dataclasses.dataclass
class FaceStylizerOptions:
  """Configurable options for building face stylizer.

  Only `model` is required; `model_options` and `hparams` default to None.
  NOTE(review): presumably the consumer of these options substitutes default
  model options / hyperparameters when a field is None -- confirm against the
  caller, which is not part of this file.

  Attributes:
    model: A model from the SupportedModels enum.
    model_options: A set of options for configuring the model.
    hparams: A set of hyperparameters used to train the face stylizer.
  """

  model: model_spec.SupportedModels
  model_options: Optional[model_opt.FaceStylizerModelOptions] = None
  hparams: Optional[hyperparameters.HParams] = None
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Hyperparameters for training on-device face stylization models."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
|
||||
|
||||
@dataclasses.dataclass
class HParams(hp.BaseHParams):
  """The hyperparameters for training face stylizers.

  Extends `hp.BaseHParams`, overriding the defaults of the inherited
  learning_rate / batch_size / epochs fields and adding the Adam beta
  coefficients.

  Attributes:
    learning_rate: Learning rate to use for gradient descent training.
    batch_size: Batch size for training.
    epochs: Number of training iterations.
    beta_1: beta_1 used in tf.keras.optimizers.Adam.
    beta_2: beta_2 used in tf.keras.optimizers.Adam.
  """

  # Parameters from BaseHParams class.
  learning_rate: float = 5e-5
  batch_size: int = 4
  epochs: int = 100
  # Parameters for face stylizer.
  # The annotations are required: @dataclass only turns *annotated* class
  # attributes into fields. Without them beta_1/beta_2 could not be passed to
  # the constructor and would be absent from repr() and equality checks.
  beta_1: float = 0.0
  beta_2: float = 0.99
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configurable model options for face stylizer models."""
|
||||
|
||||
import dataclasses
|
||||
from typing import List
|
||||
|
||||
|
||||
# TODO: Add more detailed instructions about hyperparameter tuning.
|
||||
@dataclasses.dataclass
class FaceStylizerModelOptions:
  """Model-level settings for face stylizer models.

  Attributes:
    swap_layers: The feature layers to be interpolated between the encoding
      features and the StyleGAN input features. Defaults to layers 4 through
      11 inclusive.
    alpha: Weighting coefficient for swapping layer interpolation.
    adv_loss_weight: Weighting coefficient of adversarial loss versus
      perceptual loss.
  """

  # Built through a factory so that every instance owns its own list.
  swap_layers: List[int] = dataclasses.field(
      default_factory=lambda: list(range(4, 12))
  )
  alpha: float = 1.0
  adv_loss_weight: float = 1.0
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Face stylizer model specification."""
|
||||
|
||||
import enum
|
||||
import functools
|
||||
from typing import List
|
||||
|
||||
|
||||
class ModelSpec:
  """Describes one face stylizer model configuration."""

  # Class-level constants shared by every spec.
  # NOTE(review): presumably the mean / standard deviation used to normalize
  # input pixel values -- confirm against the preprocessing code.
  mean_rgb = [127.5]
  stddev_rgb = [127.5]

  def __init__(
      self, style_block_num: int, input_image_shape: List[int], name: str = ''
  ):
    """Creates a face stylizer `ModelSpec`.

    Args:
      style_block_num: Number of style blocks in the decoder.
      input_image_shape: Input image shape, as a list of ints.
      name: Name of the model spec; defaults to the empty string.
    """
    self.name = name
    self.input_image_shape = input_image_shape
    self.style_block_num = style_block_num
|
||||
|
||||
|
||||
# Factory for the predefined BlazeFaceStylizer spec: calling it with no
# arguments constructs a ModelSpec with the keyword arguments bound below.
blaze_face_stylizer_256_spec = functools.partial(
    ModelSpec,
    name='blaze_face_stylizer_256',
    input_image_shape=[256, 256],
    style_block_num=12,
)
|
||||
|
||||
|
||||
# TODO: Document the exposed models.
|
||||
@enum.unique
class SupportedModels(enum.Enum):
  """Face stylizer model supported by MediaPipe model maker.

  Each member's value is a callable that, invoked with no arguments, builds
  the corresponding model spec.
  """

  # NOTE(review): functools.partial values in an Enum body reportedly trigger
  # a FutureWarning on Python 3.13 and stop becoming members once partial is a
  # method descriptor; wrapping in enum.member() (3.11+) would be the
  # forward-compatible form -- confirm against the supported Python versions.
  BLAZE_FACE_STYLIZER_256 = blaze_face_stylizer_256_spec

  @classmethod
  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
    """Gets model spec from the input enum and initializes it.

    Args:
      spec: The `SupportedModels` member to instantiate.

    Returns:
      The object produced by calling the member's value.

    Raises:
      TypeError: If `spec` is not a member of this enum.
    """
    # isinstance() instead of `spec not in cls`: the Enum containment check
    # is version dependent (it raises its own TypeError for non-Enum operands
    # before Python 3.12 and accepts raw member *values* from 3.12 on, which
    # would then fail at `spec.value()` with an AttributeError).
    if not isinstance(spec, cls):
      raise TypeError('Unsupported face stylizer spec: {}'.format(spec))

    return spec.value()
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
|
||||
|
||||
|
||||
class ModelSpecTest(tf.test.TestCase):
  """Checks the predefined face stylizer model spec and its enum accessor."""

  def test_predefine_spec(self):
    """Calling the spec factory directly yields the expected ModelSpec."""
    spec = ms.blaze_face_stylizer_256_spec()
    self.assertIsInstance(spec, ms.ModelSpec)
    self.assertEqual(spec.style_block_num, 12)
    self.assertAllEqual(spec.input_image_shape, [256, 256])
    self.assertEqual(spec.name, 'blaze_face_stylizer_256')

  def test_predefine_spec_enum(self):
    """SupportedModels.get() resolves the enum member to the same spec."""
    spec = ms.SupportedModels.get(ms.SupportedModels.BLAZE_FACE_STYLIZER_256)
    self.assertIsInstance(spec, ms.ModelSpec)
    self.assertEqual(spec.style_block_num, 12)
    self.assertAllEqual(spec.input_image_shape, [256, 256])
    self.assertEqual(spec.name, 'blaze_face_stylizer_256')
|
||||
|
||||
|
||||
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
|
Loading…
Reference in New Issue
Block a user