Support continual training image classifier from saved checkpoint files.
PiperOrigin-RevId: 487057612
parent b3d19fa1af
commit 0917e8cb8e
@@ -68,7 +68,10 @@ class ImageClassifier(classifier.Classifier):
   ) -> 'ImageClassifier':
     """Creates and trains an image classifier.
 
-    Loads data and trains the model based on data for image classification.
+    Loads data and trains the model based on data for image classification. If a
+    checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
+    directory, the training process will load the weight from the checkpoint
+    file for continual training.
 
     Args:
       train_data: Training data.
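The docstring change above is the user-facing contract: calling ImageClassifier.create() a second time with the same export_dir resumes from the newest saved checkpoint instead of training from scratch. A minimal usage sketch of that flow, assuming `image_classifier` is imported as in the test module below and that `train_data`/`validation_data` are datasets prepared elsewhere; the export_dir value is only illustrative:

# Sketch only; `train_data`/`validation_data` and the export_dir path are
# placeholders, and `image_classifier` is assumed to be imported as in the
# tests below.
options = image_classifier.ImageClassifierOptions(
    supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
    hparams=image_classifier.HParams(
        epochs=5, batch_size=1, shuffle=True, export_dir='/tmp/ic_export'))

# First run: trains from scratch and writes weight checkpoints under
# {export_dir}/checkpoint/ every 5 epochs.
model = image_classifier.ImageClassifier.create(
    train_data=train_data, validation_data=validation_data, options=options)

# A later run with the same export_dir finds the latest checkpoint, prints
# 'Resuming from ...', and continues training from the saved weights.
model = image_classifier.ImageClassifier.create(
    train_data=train_data, validation_data=validation_data, options=options)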
@@ -13,10 +13,13 @@
 # limitations under the License.
 
 import filecmp
+import io
 import os
+import tempfile
 
-from unittest import mock
+from unittest import mock as unittest_mock
 from absl.testing import parameterized
+import mock
 import numpy as np
 import tensorflow as tf
 
@@ -63,14 +66,20 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
           options=image_classifier.ImageClassifierOptions(
               supported_model=image_classifier.SupportedModels.MOBILENET_V2,
               hparams=image_classifier.HParams(
-                  epochs=1, batch_size=1, shuffle=True))),
+                  epochs=1,
+                  batch_size=1,
+                  shuffle=True,
+                  export_dir=tempfile.mkdtemp()))),
       dict(
           testcase_name='efficientnet_lite0',
           options=image_classifier.ImageClassifierOptions(
               supported_model=(
                   image_classifier.SupportedModels.EFFICIENTNET_LITE0),
               hparams=image_classifier.HParams(
-                  epochs=1, batch_size=1, shuffle=True))),
+                  epochs=1,
+                  batch_size=1,
+                  shuffle=True,
+                  export_dir=tempfile.mkdtemp()))),
       dict(
           testcase_name='efficientnet_lite0_change_dropout_rate',
           options=image_classifier.ImageClassifierOptions(
@@ -78,21 +87,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
                   image_classifier.SupportedModels.EFFICIENTNET_LITE0),
               model_options=image_classifier.ModelOptions(dropout_rate=0.1),
               hparams=image_classifier.HParams(
-                  epochs=1, batch_size=1, shuffle=True))),
+                  epochs=1,
+                  batch_size=1,
+                  shuffle=True,
+                  export_dir=tempfile.mkdtemp()))),
       dict(
           testcase_name='efficientnet_lite2',
           options=image_classifier.ImageClassifierOptions(
               supported_model=(
                   image_classifier.SupportedModels.EFFICIENTNET_LITE2),
               hparams=image_classifier.HParams(
-                  epochs=1, batch_size=1, shuffle=True))),
+                  epochs=1,
+                  batch_size=1,
+                  shuffle=True,
+                  export_dir=tempfile.mkdtemp()))),
       dict(
           testcase_name='efficientnet_lite4',
           options=image_classifier.ImageClassifierOptions(
               supported_model=(
                   image_classifier.SupportedModels.EFFICIENTNET_LITE4),
               hparams=image_classifier.HParams(
-                  epochs=1, batch_size=1, shuffle=True))),
+                  epochs=1,
+                  batch_size=1,
+                  shuffle=True,
+                  export_dir=tempfile.mkdtemp()))),
   )
   def test_create_and_train_model(
       self, options: image_classifier.ImageClassifierOptions):
@@ -117,16 +135,35 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     self.assertGreater(os.path.getsize(output_metadata_file), 0)
     self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
 
+  def test_continual_training_by_loading_checkpoint(self):
+    mock_stdout = io.StringIO()
+    with mock.patch('sys.stdout', mock_stdout):
+      options = image_classifier.ImageClassifierOptions(
+          supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
+          hparams=image_classifier.HParams(
+              epochs=5, batch_size=1, shuffle=True))
+      model = image_classifier.ImageClassifier.create(
+          train_data=self._train_data,
+          validation_data=self._test_data,
+          options=options)
+      model = image_classifier.ImageClassifier.create(
+          train_data=self._train_data,
+          validation_data=self._test_data,
+          options=options)
+      self._test_accuracy(model)
+
+    self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
+
   def _test_accuracy(self, model, threshold=0.0):
     _, accuracy = model.evaluate(self._test_data)
     self.assertGreaterEqual(accuracy, threshold)
 
-  @mock.patch.object(
+  @unittest_mock.patch.object(
       image_classifier.hyperparameters,
       'HParams',
       autospec=True,
       return_value=image_classifier.HParams(epochs=1))
-  @mock.patch.object(
+  @unittest_mock.patch.object(
       image_classifier.model_options,
       'ImageClassifierModelOptions',
       autospec=True,
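The new test above drives create() twice with the same options and verifies the resume path indirectly, by redirecting sys.stdout and asserting on the 'Resuming from' message. The same capture idiom in isolation, with a hypothetical noisy_step() standing in for the training call:

import io
import unittest
from unittest import mock


def noisy_step() -> None:
  # Stand-in for any code that reports progress via print().
  print('Resuming from /tmp/export/checkpoint/model-0005')


class ResumeMessageTest(unittest.TestCase):

  def test_prints_resume_message(self):
    mock_stdout = io.StringIO()
    # Everything printed inside this block is captured by mock_stdout
    # instead of going to the real stdout.
    with mock.patch('sys.stdout', mock_stdout):
      noisy_step()
    self.assertRegex(mock_stdout.getvalue(), 'Resuming from')


if __name__ == '__main__':
  unittest.main()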
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Library to train model."""
 
+import os
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.utils import model_util
@@ -78,11 +79,24 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
   loss = tf.keras.losses.CategoricalCrossentropy(
       label_smoothing=hparams.label_smoothing)
   model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
-  callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir)
+
+  summary_dir = os.path.join(hparams.export_dir, 'summaries')
+  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
+  # Save checkpoint every 5 epochs.
+  checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
+  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
+      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
+      save_weights_only=True,
+      period=5)
+
+  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
+  if latest_checkpoint:
+    print(f'Resuming from {latest_checkpoint}')
+    model.load_weights(latest_checkpoint)
 
   # Train the model.
   return model.fit(
       x=train_ds,
       epochs=hparams.epochs,
       validation_data=validation_ds,
-      callbacks=callbacks)
+      callbacks=[summary_callback, checkpoint_callback])
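The train_model changes above are the core of the feature: TensorBoard summaries go to {export_dir}/summaries, a weights-only ModelCheckpoint writes to {export_dir}/checkpoint/ every 5 epochs, and tf.train.latest_checkpoint() restores the newest checkpoint before fit(), so a rerun continues where it left off. The same pattern as a standalone sketch, assuming `model` is already compiled and `train_ds` is a prepared dataset; the function name and paths are placeholders:

import os

import tensorflow as tf


def fit_with_resume(model: tf.keras.Model, train_ds: tf.data.Dataset,
                    export_dir: str, epochs: int) -> tf.keras.callbacks.History:
  """Trains `model`, resuming from the newest checkpoint under export_dir."""
  checkpoint_path = os.path.join(export_dir, 'checkpoint')
  # Weights-only checkpoints, one prefix per epoch, written every 5 epochs
  # (mirroring period=5 above; newer TF/Keras releases drop this argument
  # in favor of save_freq).
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
      save_weights_only=True,
      period=5)

  # If an earlier run left checkpoints behind, reload the newest weights so
  # training continues from them instead of starting over.
  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
  if latest_checkpoint:
    print(f'Resuming from {latest_checkpoint}')
    model.load_weights(latest_checkpoint)

  return model.fit(x=train_ds, epochs=epochs, callbacks=[checkpoint_callback])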