Add option to omit the checkpoint callback in text classifier.

PiperOrigin-RevId: 580658724

parent ae606c1550
commit 252c7eef25
@@ -35,18 +35,23 @@ ESTIMITED_STEPS_PER_EPOCH = 1000
 def get_default_callbacks(
     export_dir: str,
+    checkpoint_frequency: int = 5,
 ) -> Sequence[tf.keras.callbacks.Callback]:
   """Gets default callbacks."""
+  callbacks = []
   summary_dir = os.path.join(export_dir, 'summaries')
   summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
+  callbacks.append(summary_callback)
 
-  checkpoint_path = os.path.join(export_dir, 'checkpoint')
-  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
-      save_weights_only=True,
-      period=5,
-  )
-  return [summary_callback, checkpoint_callback]
+  if checkpoint_frequency > 0:
+    checkpoint_path = os.path.join(export_dir, 'checkpoint')
+    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
+        os.path.join(checkpoint_path, 'model-{epoch:04d}'),
+        save_weights_only=True,
+        period=checkpoint_frequency,
+    )
+    callbacks.append(checkpoint_callback)
+  return callbacks
 
 
 def load_keras_model(
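For illustration, here is a minimal usage sketch of the new parameter (the toy model and data are hypothetical, not part of this change; get_default_callbacks is assumed importable from mediapipe.model_maker.python.core.utils.model_util). Note that period is a legacy ModelCheckpoint argument; recent Keras versions prefer save_freq and emit a deprecation warning for period.

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util

# Toy model and data, only to exercise the callbacks end to end.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')
x = tf.random.normal((16, 4))
y = tf.random.normal((16, 2))

# checkpoint_frequency=0 omits the ModelCheckpoint callback entirely,
# so only TensorBoard summaries are written under export_dir/summaries.
callbacks = model_util.get_default_callbacks(
    '/tmp/export_dir', checkpoint_frequency=0
)
model.fit(x, y, epochs=2, callbacks=list(callbacks))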
@@ -25,6 +25,23 @@ from mediapipe.model_maker.python.core.utils import test_util
 
 class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
 
+  def test_get_default_callbacks(self):
+    callbacks = model_util.get_default_callbacks(
+        'export_dir', checkpoint_frequency=5
+    )
+    self.assertLen(callbacks, 2)
+    self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard)
+    self.assertEqual(callbacks[0].log_dir, 'export_dir/summaries')
+    self.assertIsInstance(callbacks[1], tf.keras.callbacks.ModelCheckpoint)
+    self.assertEqual(callbacks[1].period, 5)
+
+    callbacks = model_util.get_default_callbacks(
+        'export_dir_2', checkpoint_frequency=0
+    )
+    self.assertLen(callbacks, 1)
+    self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard)
+    self.assertEqual(callbacks[0].log_dir, 'export_dir_2/summaries')
+
   def test_load_keras_model(self):
     input_dim = 4
     model = test_util.build_model(input_shape=[input_dim], num_classes=2)
@@ -56,6 +56,8 @@ class BertHParams(hp.BaseHParams):
       value to 0. Defaults to 2.0.
     tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
       options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
+    checkpoint_frequency: Frequency (in epochs) of saving checkpoints during
+      training. Defaults to 0, which does not save training checkpoints.
   """
 
   learning_rate: float = 3e-5
@@ -75,5 +77,7 @@ class BertHParams(hp.BaseHParams):
       bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
   )
 
+  checkpoint_frequency: int = 0
+
 
 HParams = Union[BertHParams, AverageWordEmbeddingHParams]
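As a hedged configuration sketch, the new field would be set alongside the existing hyperparameters like so (values other than checkpoint_frequency are illustrative; the module path mirrors the file touched above and may vary between releases):

from mediapipe.model_maker.python.text.text_classifier import hyperparameters

# Save weight checkpoints every 2 epochs; the default of 0 skips the
# ModelCheckpoint callback entirely.
hparams = hyperparameters.BertHParams(
    epochs=10,
    batch_size=32,
    learning_rate=3e-5,
    export_dir='/tmp/bert_exports',
    checkpoint_frequency=2,
)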
@@ -372,7 +372,9 @@ class _BertClassifier(TextClassifier):
   ):
     super().__init__(model_spec, label_names, hparams.shuffle)
     self._hparams = hparams
-    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
+    self._callbacks = model_util.get_default_callbacks(
+        self._hparams.export_dir, self._hparams.checkpoint_frequency
+    )
     self._model_options = model_options
     self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
     with self._hparams.get_strategy().scope():
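Putting it together, a hedged end-to-end sketch of opting out of checkpointing through the public Model Maker API (dataset loading is elided; train_data and validation_data are placeholder datasets, and the option and model names are assumptions based on the published Model Maker API rather than part of this commit):

from mediapipe_model_maker import text_classifier

hparams = text_classifier.BertHParams(
    export_dir='/tmp/exports',
    checkpoint_frequency=0,  # new default: no per-epoch checkpoints
)
options = text_classifier.TextClassifierOptions(
    supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
    hparams=hparams,
)
# train_data / validation_data: placeholder text_classifier.Dataset objects
# created elsewhere, e.g. with Dataset.from_csv(...).
model = text_classifier.TextClassifier.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options,
)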