Add option to omit the checkpoint callback in text classifier.
PiperOrigin-RevId: 580658724
This commit is contained in:
parent
ae606c1550
commit
252c7eef25
|
@ -35,18 +35,23 @@ ESTIMITED_STEPS_PER_EPOCH = 1000
|
|||
|
||||
def get_default_callbacks(
|
||||
export_dir: str,
|
||||
checkpoint_frequency: int = 5,
|
||||
) -> Sequence[tf.keras.callbacks.Callback]:
|
||||
"""Gets default callbacks."""
|
||||
callbacks = []
|
||||
summary_dir = os.path.join(export_dir, 'summaries')
|
||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||
callbacks.append(summary_callback)
|
||||
|
||||
if checkpoint_frequency > 0:
|
||||
checkpoint_path = os.path.join(export_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||
save_weights_only=True,
|
||||
period=5,
|
||||
period=checkpoint_frequency,
|
||||
)
|
||||
return [summary_callback, checkpoint_callback]
|
||||
callbacks.append(checkpoint_callback)
|
||||
return callbacks
|
||||
|
||||
|
||||
def load_keras_model(
|
||||
|
|
|
@ -25,6 +25,23 @@ from mediapipe.model_maker.python.core.utils import test_util
|
|||
|
||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_get_default_callbacks(self):
|
||||
callbacks = model_util.get_default_callbacks(
|
||||
'export_dir', checkpoint_frequency=5
|
||||
)
|
||||
self.assertLen(callbacks, 2)
|
||||
self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard)
|
||||
self.assertEqual(callbacks[0].log_dir, 'export_dir/summaries')
|
||||
self.assertIsInstance(callbacks[1], tf.keras.callbacks.ModelCheckpoint)
|
||||
self.assertEqual(callbacks[1].period, 5)
|
||||
|
||||
callbacks = model_util.get_default_callbacks(
|
||||
'export_dir_2', checkpoint_frequency=0
|
||||
)
|
||||
self.assertLen(callbacks, 1)
|
||||
self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard)
|
||||
self.assertEqual(callbacks[0].log_dir, 'export_dir_2/summaries')
|
||||
|
||||
def test_load_keras_model(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
|
|
|
@ -56,6 +56,8 @@ class BertHParams(hp.BaseHParams):
|
|||
value to 0. Defaults to 2.0.
|
||||
tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
|
||||
options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
|
||||
checkpoint_frequency: Frequency(in epochs) of saving checkpoints during
|
||||
training. Defaults to 0 which does not save training checkpoints.
|
||||
"""
|
||||
|
||||
learning_rate: float = 3e-5
|
||||
|
@ -75,5 +77,7 @@ class BertHParams(hp.BaseHParams):
|
|||
bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
|
||||
)
|
||||
|
||||
checkpoint_frequency: int = 0
|
||||
|
||||
|
||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
||||
|
|
|
@ -372,7 +372,9 @@ class _BertClassifier(TextClassifier):
|
|||
):
|
||||
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||
self._hparams = hparams
|
||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||
self._callbacks = model_util.get_default_callbacks(
|
||||
self._hparams.export_dir, self._hparams.checkpoint_frequency
|
||||
)
|
||||
self._model_options = model_options
|
||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||
with self._hparams.get_strategy().scope():
|
||||
|
|
Loading…
Reference in New Issue
Block a user