Add option to omit the checkpoint callback in text classifier.

PiperOrigin-RevId: 580658724
This commit is contained in:
MediaPipe Team 2023-11-08 14:27:05 -08:00 committed by Copybara-Service
parent ae606c1550
commit 252c7eef25
4 changed files with 36 additions and 8 deletions

View File

@ -35,18 +35,23 @@ ESTIMITED_STEPS_PER_EPOCH = 1000
def get_default_callbacks( def get_default_callbacks(
export_dir: str, export_dir: str,
checkpoint_frequency: int = 5,
) -> Sequence[tf.keras.callbacks.Callback]: ) -> Sequence[tf.keras.callbacks.Callback]:
"""Gets default callbacks.""" """Gets default callbacks."""
callbacks = []
summary_dir = os.path.join(export_dir, 'summaries') summary_dir = os.path.join(export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
callbacks.append(summary_callback)
if checkpoint_frequency > 0:
checkpoint_path = os.path.join(export_dir, 'checkpoint') checkpoint_path = os.path.join(export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(checkpoint_path, 'model-{epoch:04d}'), os.path.join(checkpoint_path, 'model-{epoch:04d}'),
save_weights_only=True, save_weights_only=True,
period=5, period=checkpoint_frequency,
) )
return [summary_callback, checkpoint_callback] callbacks.append(checkpoint_callback)
return callbacks
def load_keras_model( def load_keras_model(

View File

@ -25,6 +25,23 @@ from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
def test_get_default_callbacks(self):
callbacks = model_util.get_default_callbacks(
'export_dir', checkpoint_frequency=5
)
self.assertLen(callbacks, 2)
self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard)
self.assertEqual(callbacks[0].log_dir, 'export_dir/summaries')
self.assertIsInstance(callbacks[1], tf.keras.callbacks.ModelCheckpoint)
self.assertEqual(callbacks[1].period, 5)
callbacks = model_util.get_default_callbacks(
'export_dir_2', checkpoint_frequency=0
)
self.assertLen(callbacks, 1)
self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard)
self.assertEqual(callbacks[0].log_dir, 'export_dir_2/summaries')
def test_load_keras_model(self): def test_load_keras_model(self):
input_dim = 4 input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2) model = test_util.build_model(input_shape=[input_dim], num_classes=2)

View File

@ -56,6 +56,8 @@ class BertHParams(hp.BaseHParams):
value to 0. Defaults to 2.0. value to 0. Defaults to 2.0.
tokenizer: Tokenizer to use for preprocessing. Must be one of the enum tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER. options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
checkpoint_frequency: Frequency(in epochs) of saving checkpoints during
training. Defaults to 0 which does not save training checkpoints.
""" """
learning_rate: float = 3e-5 learning_rate: float = 3e-5
@ -75,5 +77,7 @@ class BertHParams(hp.BaseHParams):
bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
) )
checkpoint_frequency: int = 0
HParams = Union[BertHParams, AverageWordEmbeddingHParams] HParams = Union[BertHParams, AverageWordEmbeddingHParams]

View File

@ -372,7 +372,9 @@ class _BertClassifier(TextClassifier):
): ):
super().__init__(model_spec, label_names, hparams.shuffle) super().__init__(model_spec, label_names, hparams.shuffle)
self._hparams = hparams self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._callbacks = model_util.get_default_callbacks(
self._hparams.export_dir, self._hparams.checkpoint_frequency
)
self._model_options = model_options self._model_options = model_options
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
with self._hparams.get_strategy().scope(): with self._hparams.get_strategy().scope():