Update the face stylizer config to match the latest encoder and detector config.
PiperOrigin-RevId: 527637477
This commit is contained in:
parent
4fd77e38fb
commit
82b8e4d7bf
|
@ -178,7 +178,7 @@ class FaceStylizer(object):
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
outputs = self._decoder(
|
outputs = self._decoder(
|
||||||
{'inputs': in_latent + self.w_avg},
|
{'inputs': in_latent + self.w_avg},
|
||||||
training=False,
|
training=True,
|
||||||
)
|
)
|
||||||
gen_img = outputs['image'][-1]
|
gen_img = outputs['image'][-1]
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ class HParams(hp.BaseHParams):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Parameters from BaseHParams class.
|
# Parameters from BaseHParams class.
|
||||||
learning_rate: float = 5e-5
|
learning_rate: float = 8e-4
|
||||||
batch_size: int = 4
|
batch_size: int = 4
|
||||||
epochs: int = 100
|
epochs: int = 100
|
||||||
# Parameters for face stylizer.
|
# Parameters for face stylizer.
|
||||||
|
|
|
@ -21,7 +21,7 @@ from mediapipe.model_maker.python.core.utils import loss_functions
|
||||||
|
|
||||||
def _default_perceptual_quality_loss_weight():
|
def _default_perceptual_quality_loss_weight():
|
||||||
"""Default perceptual quality loss weight for face stylizer."""
|
"""Default perceptual quality loss weight for face stylizer."""
|
||||||
return loss_functions.PerceptualLossWeight(l1=2.0, content=20.0, style=10.0)
|
return loss_functions.PerceptualLossWeight(l1=0.5, content=4.0, style=1.0)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Add more detailed instructions about hyperparameter tuning.
|
# TODO: Add more detailed instructions about hyperparameter tuning.
|
||||||
|
@ -32,18 +32,21 @@ class FaceStylizerModelOptions:
|
||||||
Attributes:
|
Attributes:
|
||||||
swap_layers: The layers of feature to be interpolated between encoding
|
swap_layers: The layers of feature to be interpolated between encoding
|
||||||
features and StyleGAN input features.
|
features and StyleGAN input features.
|
||||||
alpha: Weighting coefficient for swapping layer interpolation.
|
alpha: Weighting coefficient of style latent for swapping layer
|
||||||
|
interpolation. Its valid range is [0, 1]. The greater weight means
|
||||||
|
stronger style is applied to the output image. Expect to set it to a small
|
||||||
|
value, i.e. < 0.1.
|
||||||
perception_loss_weight: Weighting coefficients of image perception quality
|
perception_loss_weight: Weighting coefficients of image perception quality
|
||||||
loss.
|
loss.
|
||||||
adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
|
adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
|
||||||
perceptual quality loss.
|
perceptual quality loss. It expects a small value, i.e. < 0.2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
swap_layers: Sequence[int] = dataclasses.field(
|
swap_layers: Sequence[int] = dataclasses.field(
|
||||||
default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
|
default_factory=lambda: [4, 5, 10, 11]
|
||||||
)
|
)
|
||||||
alpha: float = 1.0
|
alpha: float = 0.1
|
||||||
perception_loss_weight: loss_functions.PerceptualLossWeight = (
|
perception_loss_weight: loss_functions.PerceptualLossWeight = (
|
||||||
dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
|
dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
|
||||||
)
|
)
|
||||||
adv_loss_weight: float = 1.0
|
adv_loss_weight: float = 0.2
|
||||||
|
|
Loading…
Reference in New Issue
Block a user