Change to add the w_avg latent code to style encoding before layer swapping. This is a bug in the previous code. Also set training=True for encoder since this affect the encoding performance.
PiperOrigin-RevId: 553234376
This commit is contained in:
parent
6e54d8c204
commit
366a3290cf
|
@ -146,7 +146,7 @@ class FaceStylizer(object):
|
|||
batch_size = self._hparams.batch_size
|
||||
label_in = tf.zeros(shape=[batch_size, 0])
|
||||
|
||||
style_encoding = self._encoder(style_img)
|
||||
style_encoding = self._encoder(style_img, training=True) + self.w_avg
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(
|
||||
learning_rate=self._hparams.learning_rate,
|
||||
|
@ -176,10 +176,7 @@ class FaceStylizer(object):
|
|||
)
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
outputs = self._decoder(
|
||||
{'inputs': in_latent + self.w_avg},
|
||||
training=True,
|
||||
)
|
||||
outputs = self._decoder({'inputs': in_latent.numpy()}, training=True)
|
||||
gen_img = outputs['image'][-1]
|
||||
|
||||
real_feature = self._discriminator(
|
||||
|
@ -194,7 +191,7 @@ class FaceStylizer(object):
|
|||
tf.keras.losses.MeanAbsoluteError()(real_feature, gen_feature)
|
||||
* self._model_options.adv_loss_weight
|
||||
)
|
||||
tf.compat.v1.logging.info(f'Iteration {i} loss: {style_loss.numpy()}')
|
||||
print(f'Iteration {i} loss: {style_loss.numpy()}')
|
||||
|
||||
tvars = self._decoder.trainable_variables
|
||||
grads = tape.gradient(style_loss, tvars)
|
||||
|
|
Loading…
Reference in New Issue
Block a user