diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 2b1eebf9f..32b509797 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -35,18 +35,23 @@ ESTIMITED_STEPS_PER_EPOCH = 1000 def get_default_callbacks( export_dir: str, + checkpoint_frequency: int = 5, ) -> Sequence[tf.keras.callbacks.Callback]: """Gets default callbacks.""" + callbacks = [] summary_dir = os.path.join(export_dir, 'summaries') summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) + callbacks.append(summary_callback) - checkpoint_path = os.path.join(export_dir, 'checkpoint') - checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - os.path.join(checkpoint_path, 'model-{epoch:04d}'), - save_weights_only=True, - period=5, - ) - return [summary_callback, checkpoint_callback] + if checkpoint_frequency > 0: + checkpoint_path = os.path.join(export_dir, 'checkpoint') + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + os.path.join(checkpoint_path, 'model-{epoch:04d}'), + save_weights_only=True, + period=checkpoint_frequency, + ) + callbacks.append(checkpoint_callback) + return callbacks def load_keras_model( diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index 57750624f..ed8ba85e5 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -25,6 +25,23 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): + def test_get_default_callbacks(self): + callbacks = model_util.get_default_callbacks( + 'export_dir', checkpoint_frequency=5 + ) + self.assertLen(callbacks, 2) + self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard) + self.assertEqual(callbacks[0].log_dir, 'export_dir/summaries') + 
self.assertIsInstance(callbacks[1], tf.keras.callbacks.ModelCheckpoint) + self.assertEqual(callbacks[1].period, 5) + + callbacks = model_util.get_default_callbacks( + 'export_dir_2', checkpoint_frequency=0 + ) + self.assertLen(callbacks, 1) + self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard) + self.assertEqual(callbacks[0].log_dir, 'export_dir_2/summaries') + def test_load_keras_model(self): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) diff --git a/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py index 5d16564f5..a7dc05d5b 100644 --- a/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py +++ b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py @@ -56,6 +56,8 @@ class BertHParams(hp.BaseHParams): value to 0. Defaults to 2.0. tokenizer: Tokenizer to use for preprocessing. Must be one of the enum options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER. + checkpoint_frequency: Frequency (in epochs) of saving checkpoints during + training. Defaults to 0, which does not save training checkpoints. 
""" learning_rate: float = 3e-5 @@ -75,5 +77,7 @@ class BertHParams(hp.BaseHParams): bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER ) + checkpoint_frequency: int = 0 + HParams = Union[BertHParams, AverageWordEmbeddingHParams] diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 348f4cfb6..386e9360e 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -372,7 +372,9 @@ class _BertClassifier(TextClassifier): ): super().__init__(model_spec, label_names, hparams.shuffle) self._hparams = hparams - self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) + self._callbacks = model_util.get_default_callbacks( + self._hparams.export_dir, self._hparams.checkpoint_frequency + ) self._model_options = model_options self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None with self._hparams.get_strategy().scope():