Internal Model Maker change.
PiperOrigin-RevId: 500758488
parent 73f4636292
commit d40fa6b16d
--- a/mediapipe/model_maker/python/core/tasks/classifier.py
+++ b/mediapipe/model_maker/python/core/tasks/classifier.py
@@ -48,11 +48,12 @@ class Classifier(custom_model.CustomModel):
     self._hparams: hp.BaseHParams = None
     self._history: tf.keras.callbacks.History = None
 
-  # TODO: Integrate this into all Model Maker tasks.
+  # TODO: Integrate this into GestureRecognizer.
   def _train_model(self,
                    train_data: classification_ds.ClassificationDataset,
                    validation_data: classification_ds.ClassificationDataset,
-                   preprocessor: Optional[Callable[..., bool]] = None):
+                   preprocessor: Optional[Callable[..., bool]] = None,
+                   checkpoint_path: Optional[str] = None):
     """Trains the classifier model.
 
     Compiles and fits the tf.keras `_model` and records the `_history`.
@@ -62,6 +63,9 @@ class Classifier(custom_model.CustomModel):
       validation_data: Validation data.
       preprocessor: An optional data preprocessor that can be used when
         generating a tf.data.Dataset.
+      checkpoint_path: An optional directory for the checkpoint file to support
+        continual training. If provided, loads model weights from the latest
+        checkpoint in the directory.
     """
     tf.compat.v1.logging.info('Training the models...')
     if len(train_data) < self._hparams.batch_size:
@@ -88,6 +92,14 @@ class Classifier(custom_model.CustomModel):
         optimizer=self._optimizer,
         loss=self._loss_function,
         metrics=[self._metric_function])
+
+    latest_checkpoint = (
+        tf.train.latest_checkpoint(checkpoint_path)
+        if checkpoint_path else None)
+    if latest_checkpoint:
+      print(f'Resuming from {latest_checkpoint}')
+      self._model.load_weights(latest_checkpoint)
+
     self._history = self._model.fit(
         x=train_dataset,
         epochs=self._hparams.epochs,
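The resume logic added above hinges on tf.train.latest_checkpoint, which returns the prefix of the most recently written checkpoint in a directory, or None when the directory holds none. A minimal sketch of the same pattern; the maybe_restore helper is illustrative, not part of the diff:

    import tensorflow as tf

    def maybe_restore(model: tf.keras.Model, checkpoint_dir: str) -> bool:
      """Restores `model` from the newest checkpoint in `checkpoint_dir`, if any."""
      # Returns a checkpoint prefix (e.g. '.../model-0005'), or None if the
      # directory contains no checkpoint.
      latest = tf.train.latest_checkpoint(checkpoint_dir)
      if latest is None:
        return False
      model.load_weights(latest)
      return True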
--- a/mediapipe/model_maker/python/core/utils/model_util.py
+++ b/mediapipe/model_maker/python/core/utils/model_util.py
@@ -42,7 +42,9 @@ def get_default_callbacks(
 
   checkpoint_path = os.path.join(export_dir, 'checkpoint')
   checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      checkpoint_path, save_weights_only=True)
+      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
+      save_weights_only=True,
+      period=5)
   return [summary_callback, checkpoint_callback]
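For reference, Keras expands the '{epoch:04d}' placeholder at save time (model-0005, model-0010, ...), and period=5 writes a checkpoint only every fifth epoch; period is a legacy argument that newer TF releases replace with save_freq. A standalone sketch of an equivalent callback, with an illustrative helper name:

    import os
    import tensorflow as tf

    def make_checkpoint_callback(export_dir: str) -> tf.keras.callbacks.ModelCheckpoint:
      # Weights-only checkpoints, one file set per save, named by epoch number.
      checkpoint_path = os.path.join(export_dir, 'checkpoint')
      return tf.keras.callbacks.ModelCheckpoint(
          os.path.join(checkpoint_path, 'model-{epoch:04d}'),
          save_weights_only=True,
          period=5)  # save every 5 epochs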
--- a/mediapipe/model_maker/python/vision/image_classifier/BUILD
+++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD
@@ -87,15 +87,6 @@ py_library(
     ],
 )
 
-py_library(
-    name = "train_image_classifier_lib",
-    srcs = ["train_image_classifier_lib.py"],
-    deps = [
-        ":hyperparameters",
-        "//mediapipe/model_maker/python/core/utils:model_util",
-    ],
-)
-
 py_library(
     name = "image_classifier",
     srcs = ["image_classifier.py"],
@@ -104,7 +95,6 @@ py_library(
         ":image_classifier_options",
         ":model_options",
         ":model_spec",
-        ":train_image_classifier_lib",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
         "//mediapipe/model_maker/python/core/utils:model_util",
--- a/mediapipe/model_maker/python/vision/image_classifier/__init__.py
+++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py
@@ -35,4 +35,3 @@ del image_classifier
 del image_classifier_options
 del model_options
 del model_spec
-del train_image_classifier_lib  # pylint: disable=undefined-variable
--- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py
+++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py
@@ -28,7 +28,6 @@ from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
 from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
 from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
 from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
-from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
 from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
 from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 
@@ -57,6 +56,10 @@ class ImageClassifier(classifier.Classifier):
         mean_rgb=self._model_spec.mean_rgb,
         stddev_rgb=self._model_spec.stddev_rgb,
         use_augmentation=hparams.do_data_augmentation)
+    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
+    self._loss_function = tf.keras.losses.CategoricalCrossentropy(
+        label_smoothing=self._hparams.label_smoothing)
+    self._metric_function = 'accuracy'
     self._history = None  # Training history returned from `keras_model.fit`.
 
   @classmethod
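The new _loss_function applies label smoothing from the hyperparameters. With smoothing s, a one-hot target y is scored as y * (1 - s) + s / num_classes, so [0, 1, 0] becomes roughly [0.033, 0.933, 0.033]. A quick illustration; the 0.1 value and 3-class shapes are made up, the real value comes from hparams.label_smoothing:

    import tensorflow as tf

    # Illustrative smoothing value, not the hyperparameter default.
    loss_fn = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)

    y_true = [[0.0, 1.0, 0.0]]     # one-hot target, 3 classes
    y_pred = [[0.05, 0.90, 0.05]]  # predicted probabilities
    # Cross-entropy is computed against the smoothed target distribution.
    print(float(loss_fn(y_true, y_pred)))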
@@ -66,7 +69,7 @@ class ImageClassifier(classifier.Classifier):
       validation_data: classification_ds.ClassificationDataset,
       options: image_classifier_options.ImageClassifierOptions,
   ) -> 'ImageClassifier':
-    """Creates and trains an image classifier.
+    """Creates and trains an ImageClassifier.
 
     Loads data and trains the model based on data for image classification. If a
     checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
@@ -93,58 +96,29 @@ class ImageClassifier(classifier.Classifier):
         label_names=train_data.label_names,
         hparams=options.hparams,
         model_options=options.model_options)
-
-    image_classifier._create_model()
-
-    tf.compat.v1.logging.info('Training the models...')
-    image_classifier._train(
-        train_data=train_data, validation_data=validation_data)
-
+    image_classifier._create_and_train_model(train_data, validation_data)
     return image_classifier
 
-  # TODO: Migrate to the shared training library of Model Maker.
-  def _train(self, train_data: classification_ds.ClassificationDataset,
+  def _create_and_train_model(
+      self, train_data: classification_ds.ClassificationDataset,
       validation_data: classification_ds.ClassificationDataset):
-    """Trains the model with input train_data.
-
-    The training results are recorded by a self._history object returned by
-    tf.keras.Model.fit().
+    """Creates and trains the model and optimizer.
 
     Args:
       train_data: Training data.
       validation_data: Validation data.
     """
-
-    tf.compat.v1.logging.info('Training the models...')
-    hparams = self._hparams
-    if len(train_data) < hparams.batch_size:
-      raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
-                       'than batch_size (%d). To solve this problem, set '
-                       'the batch_size smaller or increase the size of the '
-                       'train_data.' % (len(train_data), hparams.batch_size))
-
-    train_dataset = train_data.gen_tf_dataset(
-        batch_size=hparams.batch_size,
-        is_training=True,
-        shuffle=self._shuffle,
-        preprocess=self._preprocess)
-    hparams.steps_per_epoch = model_util.get_steps_per_epoch(
-        steps_per_epoch=hparams.steps_per_epoch,
-        batch_size=hparams.batch_size,
+    self._create_model()
+    self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
+        steps_per_epoch=self._hparams.steps_per_epoch,
+        batch_size=self._hparams.batch_size,
         train_data=train_data)
-    train_dataset = train_dataset.take(count=hparams.steps_per_epoch)
-
-    validation_dataset = validation_data.gen_tf_dataset(
-        batch_size=hparams.batch_size,
-        is_training=False,
-        preprocess=self._preprocess)
-
-    # Train the model.
-    self._history = train_image_classifier_lib.train_model(
-        model=self._model,
-        hparams=hparams,
-        train_ds=train_dataset,
-        validation_ds=validation_dataset)
+    self._optimizer = self._create_optimizer()
+    self._train_model(
+        train_data=train_data,
+        validation_data=validation_data,
+        preprocessor=self._preprocess,
+        checkpoint_path=os.path.join(self._hparams.export_dir, 'checkpoint'))
 
   def _create_model(self):
     """Creates the classifier model from TFHub pretrained models."""
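Two helpers carry the new _create_and_train_model: model_util.get_steps_per_epoch resolves the steps-per-epoch value, and the extended _train_model above consumes the checkpoint directory. A hypothetical mirror of the resolution logic, assuming an unset value falls back to whole batches in the training set (the real implementation lives in model_util and may differ):

    from typing import Optional

    def get_steps_per_epoch(steps_per_epoch: Optional[int], batch_size: int,
                            train_data) -> int:
      # Hypothetical mirror of model_util.get_steps_per_epoch: an explicit
      # value wins; otherwise derive it from the dataset size.
      if steps_per_epoch is not None:
        return steps_per_epoch
      return len(train_data) // batch_size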
@@ -198,3 +172,33 @@ class ImageClassifier(classifier.Classifier):
     model_util.save_tflite(tflite_model_with_metadata, tflite_file)
     with open(metadata_file, 'w') as f:
       f.write(metadata_json)
+
+  def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
+    """Creates an optimizer with learning rate schedule.
+
+    Uses Keras CosineDecay schedule for the learning rate by default.
+
+    Returns:
+      A tf.keras.optimizers.Optimizer for model training.
+    """
+    # Learning rate is linear to batch size.
+    init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
+
+    # Get decay steps.
+    total_training_steps = self._hparams.steps_per_epoch * self._hparams.epochs
+    default_decay_steps = (
+        self._hparams.decay_samples // self._hparams.batch_size)
+    decay_steps = max(total_training_steps, default_decay_steps)
+
+    learning_rate_fn = tf.keras.experimental.CosineDecay(
+        initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
+    warmup_steps = self._hparams.warmup_epochs * self._hparams.steps_per_epoch
+    if warmup_steps:
+      learning_rate_fn = model_util.WarmUp(
+          initial_learning_rate=init_lr,
+          decay_schedule_fn=learning_rate_fn,
+          warmup_steps=warmup_steps)
+    optimizer = tf.keras.optimizers.RMSprop(
+        learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
+
+    return optimizer
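To see how the schedule behaves, here is a small self-contained sketch with made-up hyperparameter values (the real ones come from the image classifier's hparams; model_util.WarmUp is MediaPipe-internal and omitted):

    import tensorflow as tf

    base_lr, batch_size = 0.001, 128
    init_lr = base_lr * batch_size / 256       # linear scaling rule -> 0.0005
    steps_per_epoch, epochs = 100, 10
    decay_samples = 2_560_000                  # illustrative value

    total_training_steps = steps_per_epoch * epochs    # 1000
    default_decay_steps = decay_samples // batch_size  # 20000
    decay_steps = max(total_training_steps, default_decay_steps)

    # Cosine decay from init_lr down to alpha * init_lr = 0 over decay_steps.
    schedule = tf.keras.experimental.CosineDecay(
        initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate=schedule, rho=0.9, momentum=0.9, epsilon=0.001)
    print(float(schedule(0)), float(schedule(decay_steps)))  # 0.0005, 0.0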
--- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Library to train model."""
-
-import os
-import tensorflow as tf
-
-from mediapipe.model_maker.python.core.utils import model_util
-from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
-
-
-def _create_optimizer(init_lr: float, decay_steps: int,
-                      warmup_steps: int) -> tf.keras.optimizers.Optimizer:
-  """Creates an optimizer with learning rate schedule.
-
-  Uses Keras CosineDecay schedule for the learning rate by default.
-
-  Args:
-    init_lr: Initial learning rate.
-    decay_steps: Number of steps to decay over.
-    warmup_steps: Number of steps to do warmup for.
-
-  Returns:
-    A tf.keras.optimizers.Optimizer for model training.
-  """
-  learning_rate_fn = tf.keras.experimental.CosineDecay(
-      initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
-  if warmup_steps:
-    learning_rate_fn = model_util.WarmUp(
-        initial_learning_rate=init_lr,
-        decay_schedule_fn=learning_rate_fn,
-        warmup_steps=warmup_steps)
-  optimizer = tf.keras.optimizers.RMSprop(
-      learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
-
-  return optimizer
-
-
-def train_model(model: tf.keras.Model, hparams: hp.HParams,
-                train_ds: tf.data.Dataset,
-                validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
-  """Trains model with the input data and hyperparameters.
-
-  Args:
-    model: Input tf.keras.Model.
-    hparams: Hyperparameters for training image classifier.
-    train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
-    validation_ds: tf.data.Dataset, validation data to be fed in
-      tf.keras.Model.fit().
-
-  Returns:
-    The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
-  """
-
-  # Learning rate is linear to batch size.
-  learning_rate = hparams.learning_rate * hparams.batch_size / 256
-
-  # Get decay steps.
-  # NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
-  total_training_steps = hparams.steps_per_epoch * hparams.epochs
-  default_decay_steps = hparams.decay_samples // hparams.batch_size
-  decay_steps = max(total_training_steps, default_decay_steps)
-
-  warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch
-  optimizer = _create_optimizer(
-      init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps)
-
-  loss = tf.keras.losses.CategoricalCrossentropy(
-      label_smoothing=hparams.label_smoothing)
-  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
-
-  summary_dir = os.path.join(hparams.export_dir, 'summaries')
-  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
-  # Save checkpoint every 5 epochs.
-  checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
-  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
-      save_weights_only=True,
-      period=5)
-
-  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
-  if latest_checkpoint:
-    print(f'Resuming from {latest_checkpoint}')
-    model.load_weights(latest_checkpoint)
-
-  # Train the model.
-  return model.fit(
-      x=train_ds,
-      epochs=hparams.epochs,
-      validation_data=validation_ds,
-      callbacks=[summary_callback, checkpoint_callback])