diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 29b3025d8..f6edbeab4 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -68,7 +68,10 @@ class ImageClassifier(classifier.Classifier): ) -> 'ImageClassifier': """Creates and trains an image classifier. - Loads data and trains the model based on data for image classification. + Loads data and trains the model based on data for image classification. If a + checkpoint file exists in the {options.hparams.export_dir}/checkpoint/ + directory, the training process will load the weight from the checkpoint + file for continual training. Args: train_data: Training data. diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 8446df18e..252659edc 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -13,10 +13,13 @@ # limitations under the License. import filecmp +import io import os +import tempfile -from unittest import mock +from unittest import mock as unittest_mock from absl.testing import parameterized +import mock import numpy as np import tensorflow as tf @@ -63,14 +66,20 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): options=image_classifier.ImageClassifierOptions( supported_model=image_classifier.SupportedModels.MOBILENET_V2, hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), dict( testcase_name='efficientnet_lite0', options=image_classifier.ImageClassifierOptions( supported_model=( image_classifier.SupportedModels.EFFICIENTNET_LITE0), hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), dict( testcase_name='efficientnet_lite0_change_dropout_rate', options=image_classifier.ImageClassifierOptions( @@ -78,21 +87,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): image_classifier.SupportedModels.EFFICIENTNET_LITE0), model_options=image_classifier.ModelOptions(dropout_rate=0.1), hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), dict( testcase_name='efficientnet_lite2', options=image_classifier.ImageClassifierOptions( supported_model=( image_classifier.SupportedModels.EFFICIENTNET_LITE2), hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), dict( testcase_name='efficientnet_lite4', options=image_classifier.ImageClassifierOptions( supported_model=( image_classifier.SupportedModels.EFFICIENTNET_LITE4), hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), ) def test_create_and_train_model( self, options: image_classifier.ImageClassifierOptions): @@ -117,16 +135,35 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + def test_continual_training_by_loading_checkpoint(self): + mock_stdout = io.StringIO() + with mock.patch('sys.stdout', mock_stdout): + options = image_classifier.ImageClassifierOptions( + supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0, + hparams=image_classifier.HParams( + epochs=5, batch_size=1, shuffle=True)) + model = image_classifier.ImageClassifier.create( + train_data=self._train_data, + validation_data=self._test_data, + options=options) + model = image_classifier.ImageClassifier.create( + train_data=self._train_data, + validation_data=self._test_data, + options=options) + self._test_accuracy(model) + + self.assertRegex(mock_stdout.getvalue(), 'Resuming from') + def _test_accuracy(self, model, threshold=0.0): _, accuracy = model.evaluate(self._test_data) self.assertGreaterEqual(accuracy, threshold) - @mock.patch.object( + @unittest_mock.patch.object( image_classifier.hyperparameters, 'HParams', autospec=True, return_value=image_classifier.HParams(epochs=1)) - @mock.patch.object( + @unittest_mock.patch.object( image_classifier.model_options, 'ImageClassifierModelOptions', autospec=True, diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py index 4adddefeb..c5b28cff5 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py +++ b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py @@ -13,6 +13,7 @@ # limitations under the License. """Library to train model.""" +import os import tensorflow as tf from mediapipe.model_maker.python.core.utils import model_util @@ -78,11 +79,24 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams, loss = tf.keras.losses.CategoricalCrossentropy( label_smoothing=hparams.label_smoothing) model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) - callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir) + + summary_dir = os.path.join(hparams.export_dir, 'summaries') + summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) + # Save checkpoint every 5 epochs. + checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint') + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + os.path.join(checkpoint_path, 'model-{epoch:04d}'), + save_weights_only=True, + period=5) + + latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) + if latest_checkpoint: + print(f'Resuming from {latest_checkpoint}') + model.load_weights(latest_checkpoint) # Train the model. return model.fit( x=train_ds, epochs=hparams.epochs, validation_data=validation_ds, - callbacks=callbacks) + callbacks=[summary_callback, checkpoint_callback])