Switches to tf.keras.optimizers.experimental.AdamW instead of the legacy AdamW.
PiperOrigin-RevId: 496821354
This commit is contained in:
parent
151e447614
commit
5c0f548f5f
|
@ -417,8 +417,22 @@ class _BertClassifier(TextClassifier):
|
|||
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
|
||||
warmup_steps = int(total_steps * 0.1)
|
||||
initial_lr = self._hparams.learning_rate
|
||||
self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
|
||||
warmup_steps)
|
||||
# Implements linear decay of the learning rate.
|
||||
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
|
||||
initial_learning_rate=initial_lr,
|
||||
decay_steps=total_steps,
|
||||
end_learning_rate=0.0,
|
||||
power=1.0)
|
||||
if warmup_steps:
|
||||
lr_schedule = optimization.WarmUp(
|
||||
initial_learning_rate=initial_lr,
|
||||
decay_schedule_fn=lr_schedule,
|
||||
warmup_steps=warmup_steps)
|
||||
|
||||
self._optimizer = tf.keras.optimizers.experimental.AdamW(
|
||||
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
|
||||
self._optimizer.exclude_from_weight_decay(
|
||||
var_names=["LayerNorm", "layer_norm", "bias"])
|
||||
|
||||
def _save_vocab(self, vocab_filepath: str):
|
||||
tf.io.gfile.copy(
|
||||
|
|
Loading…
Reference in New Issue
Block a user