Change to add the w_avg latent code to style encoding before layer swapping. This is a bug in the previous code. Also set training=True for encoder since this affect the encoding performance.

PiperOrigin-RevId: 553234376
This commit is contained in:
MediaPipe Team 2023-08-02 13:09:01 -07:00 committed by Copybara-Service
parent 6e54d8c204
commit 366a3290cf

View File

@ -146,7 +146,7 @@ class FaceStylizer(object):
batch_size = self._hparams.batch_size batch_size = self._hparams.batch_size
label_in = tf.zeros(shape=[batch_size, 0]) label_in = tf.zeros(shape=[batch_size, 0])
style_encoding = self._encoder(style_img) style_encoding = self._encoder(style_img, training=True) + self.w_avg
optimizer = tf.keras.optimizers.Adam( optimizer = tf.keras.optimizers.Adam(
learning_rate=self._hparams.learning_rate, learning_rate=self._hparams.learning_rate,
@ -176,10 +176,7 @@ class FaceStylizer(object):
) )
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = self._decoder( outputs = self._decoder({'inputs': in_latent.numpy()}, training=True)
{'inputs': in_latent + self.w_avg},
training=True,
)
gen_img = outputs['image'][-1] gen_img = outputs['image'][-1]
real_feature = self._discriminator( real_feature = self._discriminator(
@ -194,7 +191,7 @@ class FaceStylizer(object):
tf.keras.losses.MeanAbsoluteError()(real_feature, gen_feature) tf.keras.losses.MeanAbsoluteError()(real_feature, gen_feature)
* self._model_options.adv_loss_weight * self._model_options.adv_loss_weight
) )
tf.compat.v1.logging.info(f'Iteration {i} loss: {style_loss.numpy()}') print(f'Iteration {i} loss: {style_loss.numpy()}')
tvars = self._decoder.trainable_variables tvars = self._decoder.trainable_variables
grads = tape.gradient(style_loss, tvars) grads = tape.gradient(style_loss, tvars)