Switches to tf.keras.optimizers.experimental.AdamW instead of the legacy AdamW.

PiperOrigin-RevId: 496821354
This commit is contained in:
MediaPipe Team 2022-12-20 20:51:23 -08:00 committed by Copybara-Service
parent 151e447614
commit 5c0f548f5f

View File

@ -417,8 +417,22 @@ class _BertClassifier(TextClassifier):
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
warmup_steps = int(total_steps * 0.1) warmup_steps = int(total_steps * 0.1)
initial_lr = self._hparams.learning_rate initial_lr = self._hparams.learning_rate
self._optimizer = optimization.create_optimizer(initial_lr, total_steps, # Implements linear decay of the learning rate.
warmup_steps) lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=initial_lr,
decay_steps=total_steps,
end_learning_rate=0.0,
power=1.0)
if warmup_steps:
lr_schedule = optimization.WarmUp(
initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps)
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"])
def _save_vocab(self, vocab_filepath: str): def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy( tf.io.gfile.copy(