Update the face stylizer config to match the latest encoder and detector config.

PiperOrigin-RevId: 527637477
This commit is contained in:
MediaPipe Team 2023-04-27 11:35:01 -07:00 committed by Copybara-Service
parent 4fd77e38fb
commit 82b8e4d7bf
3 changed files with 11 additions and 8 deletions

View File

@ -178,7 +178,7 @@ class FaceStylizer(object):
with tf.GradientTape() as tape:
outputs = self._decoder(
{'inputs': in_latent + self.w_avg},
training=False,
training=True,
)
gen_img = outputs['image'][-1]

View File

@ -31,7 +31,7 @@ class HParams(hp.BaseHParams):
"""
# Parameters from BaseHParams class.
learning_rate: float = 5e-5
learning_rate: float = 8e-4
batch_size: int = 4
epochs: int = 100
# Parameters for face stylizer.

View File

@ -21,7 +21,7 @@ from mediapipe.model_maker.python.core.utils import loss_functions
def _default_perceptual_quality_loss_weight():
"""Default perceptual quality loss weight for face stylizer."""
return loss_functions.PerceptualLossWeight(l1=2.0, content=20.0, style=10.0)
return loss_functions.PerceptualLossWeight(l1=0.5, content=4.0, style=1.0)
# TODO: Add more detailed instructions about hyperparameter tuning.
@ -32,18 +32,21 @@ class FaceStylizerModelOptions:
Attributes:
swap_layers: The layers of feature to be interpolated between encoding
features and StyleGAN input features.
alpha: Weighting coefficient for swapping layer interpolation.
alpha: Weighting coefficient of style latent for swapping layer
interpolation. Its valid range is [0, 1]. The greater weight means
stronger style is applied to the output image. Expect to set it to a small
value, i.e. < 0.1.
perception_loss_weight: Weighting coefficients of image perception quality
loss.
adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
perceptual quality loss.
perceptual quality loss. It expects a small value, i.e. < 0.2.
"""
swap_layers: Sequence[int] = dataclasses.field(
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
default_factory=lambda: [4, 5, 10, 11]
)
alpha: float = 1.0
alpha: float = 0.1
perception_loss_weight: loss_functions.PerceptualLossWeight = (
dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
)
adv_loss_weight: float = 1.0
adv_loss_weight: float = 0.2