Update the face stylizer config to match the latest encoder and detector config.
PiperOrigin-RevId: 527637477
This commit is contained in:
parent
4fd77e38fb
commit
82b8e4d7bf
|
@ -178,7 +178,7 @@ class FaceStylizer(object):
|
|||
with tf.GradientTape() as tape:
|
||||
outputs = self._decoder(
|
||||
{'inputs': in_latent + self.w_avg},
|
||||
training=False,
|
||||
training=True,
|
||||
)
|
||||
gen_img = outputs['image'][-1]
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class HParams(hp.BaseHParams):
|
|||
"""
|
||||
|
||||
# Parameters from BaseHParams class.
|
||||
learning_rate: float = 5e-5
|
||||
learning_rate: float = 8e-4
|
||||
batch_size: int = 4
|
||||
epochs: int = 100
|
||||
# Parameters for face stylizer.
|
||||
|
|
|
@ -21,7 +21,7 @@ from mediapipe.model_maker.python.core.utils import loss_functions
|
|||
|
||||
def _default_perceptual_quality_loss_weight():
|
||||
"""Default perceptual quality loss weight for face stylizer."""
|
||||
return loss_functions.PerceptualLossWeight(l1=2.0, content=20.0, style=10.0)
|
||||
return loss_functions.PerceptualLossWeight(l1=0.5, content=4.0, style=1.0)
|
||||
|
||||
|
||||
# TODO: Add more detailed instructions about hyperparameter tuning.
|
||||
|
@ -32,18 +32,21 @@ class FaceStylizerModelOptions:
|
|||
Attributes:
|
||||
swap_layers: The layers of feature to be interpolated between encoding
|
||||
features and StyleGAN input features.
|
||||
alpha: Weighting coefficient for swapping layer interpolation.
|
||||
alpha: Weighting coefficient of style latent for swapping layer
|
||||
interpolation. Its valid range is [0, 1]. The greater weight means
|
||||
stronger style is applied to the output image. Expect to set it to a small
|
||||
value, i.e. < 0.1.
|
||||
perception_loss_weight: Weighting coefficients of image perception quality
|
||||
loss.
|
||||
adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
|
||||
perceptual quality loss.
|
||||
perceptual quality loss. It expects a small value, i.e. < 0.2.
|
||||
"""
|
||||
|
||||
swap_layers: Sequence[int] = dataclasses.field(
|
||||
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
|
||||
default_factory=lambda: [4, 5, 10, 11]
|
||||
)
|
||||
alpha: float = 1.0
|
||||
alpha: float = 0.1
|
||||
perception_loss_weight: loss_functions.PerceptualLossWeight = (
|
||||
dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
|
||||
)
|
||||
adv_loss_weight: float = 1.0
|
||||
adv_loss_weight: float = 0.2
|
||||
|
|
Loading…
Reference in New Issue
Block a user