Update the face stylizer config to match the latest encoder and detector config.

PiperOrigin-RevId: 527637477
This commit is contained in:
MediaPipe Team 2023-04-27 11:35:01 -07:00 committed by Copybara-Service
parent 4fd77e38fb
commit 82b8e4d7bf
3 changed files with 11 additions and 8 deletions

View File

@ -178,7 +178,7 @@ class FaceStylizer(object):
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = self._decoder( outputs = self._decoder(
{'inputs': in_latent + self.w_avg}, {'inputs': in_latent + self.w_avg},
training=False, training=True,
) )
gen_img = outputs['image'][-1] gen_img = outputs['image'][-1]

View File

@ -31,7 +31,7 @@ class HParams(hp.BaseHParams):
""" """
# Parameters from BaseHParams class. # Parameters from BaseHParams class.
learning_rate: float = 5e-5 learning_rate: float = 8e-4
batch_size: int = 4 batch_size: int = 4
epochs: int = 100 epochs: int = 100
# Parameters for face stylizer. # Parameters for face stylizer.

View File

@ -21,7 +21,7 @@ from mediapipe.model_maker.python.core.utils import loss_functions
def _default_perceptual_quality_loss_weight(): def _default_perceptual_quality_loss_weight():
"""Default perceptual quality loss weight for face stylizer.""" """Default perceptual quality loss weight for face stylizer."""
return loss_functions.PerceptualLossWeight(l1=2.0, content=20.0, style=10.0) return loss_functions.PerceptualLossWeight(l1=0.5, content=4.0, style=1.0)
# TODO: Add more detailed instructions about hyperparameter tuning. # TODO: Add more detailed instructions about hyperparameter tuning.
@ -32,18 +32,21 @@ class FaceStylizerModelOptions:
Attributes: Attributes:
swap_layers: The layers of feature to be interpolated between encoding swap_layers: The layers of feature to be interpolated between encoding
features and StyleGAN input features. features and StyleGAN input features.
alpha: Weighting coefficient for swapping layer interpolation. alpha: Weighting coefficient of style latent for swapping layer
interpolation. Its valid range is [0, 1]. The greater weight means
stronger style is applied to the output image. Expect to set it to a small
value, i.e. < 0.1.
perception_loss_weight: Weighting coefficients of image perception quality perception_loss_weight: Weighting coefficients of image perception quality
loss. loss.
adv_loss_weight: Weighting coeffcieint of adversarial loss versus image adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
perceptual quality loss. perceptual quality loss. It expects a small value, i.e. < 0.2.
""" """
swap_layers: Sequence[int] = dataclasses.field( swap_layers: Sequence[int] = dataclasses.field(
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11] default_factory=lambda: [4, 5, 10, 11]
) )
alpha: float = 1.0 alpha: float = 0.1
perception_loss_weight: loss_functions.PerceptualLossWeight = ( perception_loss_weight: loss_functions.PerceptualLossWeight = (
dataclasses.field(default_factory=_default_perceptual_quality_loss_weight) dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
) )
adv_loss_weight: float = 1.0 adv_loss_weight: float = 0.2