From 5c0f548f5f5b31d94b749456cdac306b5330dfa3 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 20 Dec 2022 20:51:23 -0800
Subject: [PATCH] Switches to tf.keras.optimizers.experimental.AdamW instead
 of the legacy AdamW.

PiperOrigin-RevId: 496821354
---
 .../text/text_classifier/text_classifier.py  | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
index c285702d2..c4d3fdbe2 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
@@ -417,8 +417,22 @@ class _BertClassifier(TextClassifier):
     total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
     warmup_steps = int(total_steps * 0.1)
     initial_lr = self._hparams.learning_rate
-    self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
-                                                    warmup_steps)
+    # Implements linear decay of the learning rate.
+    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+        initial_learning_rate=initial_lr,
+        decay_steps=total_steps,
+        end_learning_rate=0.0,
+        power=1.0)
+    if warmup_steps:
+      lr_schedule = optimization.WarmUp(
+          initial_learning_rate=initial_lr,
+          decay_schedule_fn=lr_schedule,
+          warmup_steps=warmup_steps)
+
+    self._optimizer = tf.keras.optimizers.experimental.AdamW(
+        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
+    self._optimizer.exclude_from_weight_decay(
+        var_names=["LayerNorm", "layer_norm", "bias"])
 
   def _save_vocab(self, vocab_filepath: str):
     tf.io.gfile.copy(
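
For context: the removed call, optimization.create_optimizer() from
tf-models-official, bundled the linear-decay schedule, warmup, and
AdamW-style decoupled weight decay into a single helper. The new code builds
the same pieces explicitly from Keras primitives. The sketch below shows that
construction standalone, assuming TF >= 2.11 (for
tf.keras.optimizers.experimental.AdamW and its exclude_from_weight_decay
method) and tf-models-official for optimization.WarmUp; the step count and
learning rate are illustrative placeholders, not values from the patch.

    import tensorflow as tf
    from official.nlp import optimization  # tf-models-official

    total_steps = 1000  # illustrative; the patch derives this from hparams
    warmup_steps = int(total_steps * 0.1)  # 10% warmup, as in the patch
    initial_lr = 3e-5   # illustrative

    # PolynomialDecay with power=1.0 is a straight line from initial_lr
    # down to 0 over total_steps.
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=total_steps,
        end_learning_rate=0.0,
        power=1.0)

    # Wrap the decay schedule with a linear ramp over the first
    # warmup_steps steps.
    lr_schedule = optimization.WarmUp(
        initial_learning_rate=initial_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=warmup_steps)

    optimizer = tf.keras.optimizers.experimental.AdamW(
        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
    # Decoupled weight decay should not shrink LayerNorm scales or biases;
    # matching is by substring of the variable name.
    optimizer.exclude_from_weight_decay(
        var_names=["LayerNorm", "layer_norm", "bias"])

    # Sanity check: lr ramps up through warmup, then decays linearly to 0.
    for step in [0, 50, 100, 550, 1000]:
        print(step, float(lr_schedule(step)))

Excluding LayerNorm and bias variables from decay matches the standard BERT
fine-tuning recipe, which the legacy create_optimizer() helper applied by
default, so the exclude_from_weight_decay() call keeps the new optimizer's
behavior aligned with the old one.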