Internal change - migration
PiperOrigin-RevId: 486853689
This commit is contained in:
		
							parent
							
								
									0a08e4768b
								
							
						
					
					
						commit
						24d03451c7
					
				|  | @ -19,7 +19,7 @@ from __future__ import print_function | |||
| 
 | ||||
| import os | ||||
| import tempfile | ||||
| from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union | ||||
| from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union | ||||
| 
 | ||||
| # Dependency imports | ||||
| 
 | ||||
|  | @ -34,6 +34,19 @@ DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 | |||
| ESTIMITED_STEPS_PER_EPOCH = 1000 | ||||
| 
 | ||||
| 
 | ||||
| def get_default_callbacks( | ||||
|     export_dir: str) -> Sequence[tf.keras.callbacks.Callback]: | ||||
|   """Gets default callbacks.""" | ||||
|   summary_dir = os.path.join(export_dir, 'summaries') | ||||
|   summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) | ||||
|   # Save checkpoint every 20 epochs. | ||||
| 
 | ||||
|   checkpoint_path = os.path.join(export_dir, 'checkpoint') | ||||
|   checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( | ||||
|       checkpoint_path, save_weights_only=True, period=20) | ||||
|   return [summary_callback, checkpoint_callback] | ||||
| 
 | ||||
| 
 | ||||
| def load_keras_model(model_path: str, | ||||
|                      compile_on_load: bool = False) -> tf.keras.Model: | ||||
|   """Loads a tensorflow Keras model from file and returns the Keras model. | ||||
|  | @ -174,7 +187,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): | |||
|           lambda: self.decay_schedule_fn(step), | ||||
|           name=name) | ||||
| 
 | ||||
|   def get_config(self) -> Dict[Text, Any]: | ||||
|   def get_config(self) -> Dict[str, Any]: | ||||
|     return { | ||||
|         'initial_learning_rate': self.initial_learning_rate, | ||||
|         'decay_schedule_fn': self.decay_schedule_fn, | ||||
|  |  | |||
|  | @ -13,9 +13,6 @@ | |||
| # limitations under the License. | ||||
| """Library to train model.""" | ||||
| 
 | ||||
| import os | ||||
| from typing import List | ||||
| 
 | ||||
| import tensorflow as tf | ||||
| 
 | ||||
| from mediapipe.model_maker.python.core.utils import model_util | ||||
|  | @ -49,19 +46,6 @@ def _create_optimizer(init_lr: float, decay_steps: int, | |||
|   return optimizer | ||||
| 
 | ||||
| 
 | ||||
| def _get_default_callbacks( | ||||
|     export_dir: str) -> List[tf.keras.callbacks.Callback]: | ||||
|   """Gets default callbacks.""" | ||||
|   summary_dir = os.path.join(export_dir, 'summaries') | ||||
|   summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) | ||||
|   # Save checkpoint every 20 epochs. | ||||
| 
 | ||||
|   checkpoint_path = os.path.join(export_dir, 'checkpoint') | ||||
|   checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( | ||||
|       checkpoint_path, save_weights_only=True, period=20) | ||||
|   return [summary_callback, checkpoint_callback] | ||||
| 
 | ||||
| 
 | ||||
| def train_model(model: tf.keras.Model, hparams: hp.HParams, | ||||
|                 train_ds: tf.data.Dataset, | ||||
|                 validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History: | ||||
|  | @ -94,7 +78,7 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams, | |||
|   loss = tf.keras.losses.CategoricalCrossentropy( | ||||
|       label_smoothing=hparams.label_smoothing) | ||||
|   model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) | ||||
|   callbacks = _get_default_callbacks(export_dir=hparams.export_dir) | ||||
|   callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir) | ||||
| 
 | ||||
|   # Train the model. | ||||
|   return model.fit( | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user