Switches to tf.keras.optimizers.experimental.AdamW instead of the legacy AdamW.
PiperOrigin-RevId: 496821354
commit 5c0f548f5f
parent 151e447614
@@ -417,8 +417,22 @@ class _BertClassifier(TextClassifier):
     total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
     warmup_steps = int(total_steps * 0.1)
     initial_lr = self._hparams.learning_rate
-    self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
-                                                    warmup_steps)
+    # Implements linear decay of the learning rate.
+    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+        initial_learning_rate=initial_lr,
+        decay_steps=total_steps,
+        end_learning_rate=0.0,
+        power=1.0)
+    if warmup_steps:
+      lr_schedule = optimization.WarmUp(
+          initial_learning_rate=initial_lr,
+          decay_schedule_fn=lr_schedule,
+          warmup_steps=warmup_steps)
+
+    self._optimizer = tf.keras.optimizers.experimental.AdamW(
+        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
+    self._optimizer.exclude_from_weight_decay(
+        var_names=["LayerNorm", "layer_norm", "bias"])
 
   def _save_vocab(self, vocab_filepath: str):
     tf.io.gfile.copy(
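For context, below is a minimal standalone sketch of what the added lines set up: a linear PolynomialDecay schedule, optionally wrapped in a warmup phase, feeding tf.keras.optimizers.experimental.AdamW with weight decay excluded from LayerNorm and bias variables. The step count and learning rate values are illustrative placeholders (the real code reads them from self._hparams), and LinearWarmUp is a simplified stand-in for the Model Garden's optimization.WarmUp used in the diff, not the class itself.

import tensorflow as tf

# Illustrative values only; in the real code these come from self._hparams.
total_steps = 1000
warmup_steps = int(total_steps * 0.1)
initial_lr = 3e-5

# With power=1.0, PolynomialDecay is a straight linear decay from
# initial_lr down to 0 over total_steps.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=initial_lr,
    decay_steps=total_steps,
    end_learning_rate=0.0,
    power=1.0)


class LinearWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Simplified stand-in for the Model Garden's optimization.WarmUp.

  Ramps the learning rate linearly from 0 to initial_learning_rate over
  warmup_steps, then defers to the wrapped decay schedule.
  """

  def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps):
    super().__init__()
    self.initial_learning_rate = initial_learning_rate
    self.decay_schedule_fn = decay_schedule_fn
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    warmup_lr = self.initial_learning_rate * step / float(self.warmup_steps)
    return tf.cond(step < self.warmup_steps,
                   lambda: warmup_lr,
                   lambda: self.decay_schedule_fn(step))


if warmup_steps:
  lr_schedule = LinearWarmUp(initial_lr, lr_schedule, warmup_steps)

# Decoupled weight decay with BERT-style epsilon and global gradient clipping.
optimizer = tf.keras.optimizers.experimental.AdamW(
    lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
# Normalization parameters and biases are conventionally exempt from decay.
optimizer.exclude_from_weight_decay(
    var_names=["LayerNorm", "layer_norm", "bias"])

# Sanity check: the rate ramps up over the first 10% of steps,
# then decays linearly toward 0.
for step in (0, 50, 100, 500, 1000):
  print(step, float(lr_schedule(step)))

In effect this keeps the training recipe of the legacy create_optimizer path (linear decay with warmup and decoupled weight decay that skips normalization and bias variables) while using the Keras-native AdamW implementation.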