Add option to omit the checkpoint callback in text classifier.
PiperOrigin-RevId: 580658724
This commit is contained in:
		
							parent
							
								
									ae606c1550
								
							
						
					
					
						commit
						252c7eef25
					
				| 
						 | 
					@ -35,18 +35,23 @@ ESTIMITED_STEPS_PER_EPOCH = 1000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_default_callbacks(
 | 
					def get_default_callbacks(
 | 
				
			||||||
    export_dir: str,
 | 
					    export_dir: str,
 | 
				
			||||||
 | 
					    checkpoint_frequency: int = 5,
 | 
				
			||||||
) -> Sequence[tf.keras.callbacks.Callback]:
 | 
					) -> Sequence[tf.keras.callbacks.Callback]:
 | 
				
			||||||
  """Gets default callbacks."""
 | 
					  """Gets default callbacks."""
 | 
				
			||||||
 | 
					  callbacks = []
 | 
				
			||||||
  summary_dir = os.path.join(export_dir, 'summaries')
 | 
					  summary_dir = os.path.join(export_dir, 'summaries')
 | 
				
			||||||
  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
 | 
					  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
 | 
				
			||||||
 | 
					  callbacks.append(summary_callback)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  checkpoint_path = os.path.join(export_dir, 'checkpoint')
 | 
					  if checkpoint_frequency > 0:
 | 
				
			||||||
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
 | 
					    checkpoint_path = os.path.join(export_dir, 'checkpoint')
 | 
				
			||||||
      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
 | 
					    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
 | 
				
			||||||
      save_weights_only=True,
 | 
					        os.path.join(checkpoint_path, 'model-{epoch:04d}'),
 | 
				
			||||||
      period=5,
 | 
					        save_weights_only=True,
 | 
				
			||||||
  )
 | 
					        period=checkpoint_frequency,
 | 
				
			||||||
  return [summary_callback, checkpoint_callback]
 | 
					    )
 | 
				
			||||||
 | 
					    callbacks.append(checkpoint_callback)
 | 
				
			||||||
 | 
					  return callbacks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def load_keras_model(
 | 
					def load_keras_model(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,6 +25,23 @@ from mediapipe.model_maker.python.core.utils import test_util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
 | 
					class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_get_default_callbacks(self):
 | 
				
			||||||
 | 
					    callbacks = model_util.get_default_callbacks(
 | 
				
			||||||
 | 
					        'export_dir', checkpoint_frequency=5
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self.assertLen(callbacks, 2)
 | 
				
			||||||
 | 
					    self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard)
 | 
				
			||||||
 | 
					    self.assertEqual(callbacks[0].log_dir, 'export_dir/summaries')
 | 
				
			||||||
 | 
					    self.assertIsInstance(callbacks[1], tf.keras.callbacks.ModelCheckpoint)
 | 
				
			||||||
 | 
					    self.assertEqual(callbacks[1].period, 5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    callbacks = model_util.get_default_callbacks(
 | 
				
			||||||
 | 
					        'export_dir_2', checkpoint_frequency=0
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self.assertLen(callbacks, 1)
 | 
				
			||||||
 | 
					    self.assertIsInstance(callbacks[0], tf.keras.callbacks.TensorBoard)
 | 
				
			||||||
 | 
					    self.assertEqual(callbacks[0].log_dir, 'export_dir_2/summaries')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_load_keras_model(self):
 | 
					  def test_load_keras_model(self):
 | 
				
			||||||
    input_dim = 4
 | 
					    input_dim = 4
 | 
				
			||||||
    model = test_util.build_model(input_shape=[input_dim], num_classes=2)
 | 
					    model = test_util.build_model(input_shape=[input_dim], num_classes=2)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -56,6 +56,8 @@ class BertHParams(hp.BaseHParams):
 | 
				
			||||||
      value to 0. Defaults to 2.0.
 | 
					      value to 0. Defaults to 2.0.
 | 
				
			||||||
    tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
 | 
					    tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
 | 
				
			||||||
      options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
 | 
					      options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
 | 
				
			||||||
 | 
					    checkpoint_frequency: Frequency(in epochs) of saving checkpoints during
 | 
				
			||||||
 | 
					      training. Defaults to 0 which does not save training checkpoints.
 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  learning_rate: float = 3e-5
 | 
					  learning_rate: float = 3e-5
 | 
				
			||||||
| 
						 | 
					@ -75,5 +77,7 @@ class BertHParams(hp.BaseHParams):
 | 
				
			||||||
      bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
 | 
					      bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  checkpoint_frequency: int = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
 | 
					HParams = Union[BertHParams, AverageWordEmbeddingHParams]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -372,7 +372,9 @@ class _BertClassifier(TextClassifier):
 | 
				
			||||||
  ):
 | 
					  ):
 | 
				
			||||||
    super().__init__(model_spec, label_names, hparams.shuffle)
 | 
					    super().__init__(model_spec, label_names, hparams.shuffle)
 | 
				
			||||||
    self._hparams = hparams
 | 
					    self._hparams = hparams
 | 
				
			||||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
					    self._callbacks = model_util.get_default_callbacks(
 | 
				
			||||||
 | 
					        self._hparams.export_dir, self._hparams.checkpoint_frequency
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    self._model_options = model_options
 | 
					    self._model_options = model_options
 | 
				
			||||||
    self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
 | 
					    self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
 | 
				
			||||||
    with self._hparams.get_strategy().scope():
 | 
					    with self._hparams.get_strategy().scope():
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user