Update the face stylizer config to match the latest encoder and detector config.
PiperOrigin-RevId: 527637477
This commit is contained in:
		
							parent
							
								
									4fd77e38fb
								
							
						
					
					
						commit
						82b8e4d7bf
					
				| 
						 | 
					@ -178,7 +178,7 @@ class FaceStylizer(object):
 | 
				
			||||||
      with tf.GradientTape() as tape:
 | 
					      with tf.GradientTape() as tape:
 | 
				
			||||||
        outputs = self._decoder(
 | 
					        outputs = self._decoder(
 | 
				
			||||||
            {'inputs': in_latent + self.w_avg},
 | 
					            {'inputs': in_latent + self.w_avg},
 | 
				
			||||||
            training=False,
 | 
					            training=True,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        gen_img = outputs['image'][-1]
 | 
					        gen_img = outputs['image'][-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -31,7 +31,7 @@ class HParams(hp.BaseHParams):
 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Parameters from BaseHParams class.
 | 
					  # Parameters from BaseHParams class.
 | 
				
			||||||
  learning_rate: float = 5e-5
 | 
					  learning_rate: float = 8e-4
 | 
				
			||||||
  batch_size: int = 4
 | 
					  batch_size: int = 4
 | 
				
			||||||
  epochs: int = 100
 | 
					  epochs: int = 100
 | 
				
			||||||
  # Parameters for face stylizer.
 | 
					  # Parameters for face stylizer.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -21,7 +21,7 @@ from mediapipe.model_maker.python.core.utils import loss_functions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _default_perceptual_quality_loss_weight():
 | 
					def _default_perceptual_quality_loss_weight():
 | 
				
			||||||
  """Default perceptual quality loss weight for face stylizer."""
 | 
					  """Default perceptual quality loss weight for face stylizer."""
 | 
				
			||||||
  return loss_functions.PerceptualLossWeight(l1=2.0, content=20.0, style=10.0)
 | 
					  return loss_functions.PerceptualLossWeight(l1=0.5, content=4.0, style=1.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: Add more detailed instructions about hyperparameter tuning.
 | 
					# TODO: Add more detailed instructions about hyperparameter tuning.
 | 
				
			||||||
| 
						 | 
					@ -32,18 +32,21 @@ class FaceStylizerModelOptions:
 | 
				
			||||||
  Attributes:
 | 
					  Attributes:
 | 
				
			||||||
    swap_layers: The layers of feature to be interpolated between encoding
 | 
					    swap_layers: The layers of feature to be interpolated between encoding
 | 
				
			||||||
      features and StyleGAN input features.
 | 
					      features and StyleGAN input features.
 | 
				
			||||||
    alpha: Weighting coefficient for swapping layer interpolation.
 | 
					    alpha: Weighting coefficient of style latent for swapping layer
 | 
				
			||||||
 | 
					      interpolation. Its valid range is [0, 1]. The greater weight means
 | 
				
			||||||
 | 
					      stronger style is applied to the output image. Expect to set it to a small
 | 
				
			||||||
 | 
					      value, i.e. < 0.1.
 | 
				
			||||||
    perception_loss_weight: Weighting coefficients of image perception quality
 | 
					    perception_loss_weight: Weighting coefficients of image perception quality
 | 
				
			||||||
      loss.
 | 
					      loss.
 | 
				
			||||||
    adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
 | 
					    adv_loss_weight: Weighting coeffcieint of adversarial loss versus image
 | 
				
			||||||
      perceptual quality loss.
 | 
					      perceptual quality loss. It expects a small value, i.e. < 0.2.
 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  swap_layers: Sequence[int] = dataclasses.field(
 | 
					  swap_layers: Sequence[int] = dataclasses.field(
 | 
				
			||||||
      default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11]
 | 
					      default_factory=lambda: [4, 5, 10, 11]
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  alpha: float = 1.0
 | 
					  alpha: float = 0.1
 | 
				
			||||||
  perception_loss_weight: loss_functions.PerceptualLossWeight = (
 | 
					  perception_loss_weight: loss_functions.PerceptualLossWeight = (
 | 
				
			||||||
      dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
 | 
					      dataclasses.field(default_factory=_default_perceptual_quality_loss_weight)
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  adv_loss_weight: float = 1.0
 | 
					  adv_loss_weight: float = 0.2
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user