Support continual training image classifier from saved checkpoint files.

PiperOrigin-RevId: 487057612
This commit is contained in:
MediaPipe Team 2022-11-08 14:47:19 -08:00 committed by Copybara-Service
parent b3d19fa1af
commit 0917e8cb8e
3 changed files with 65 additions and 11 deletions

View File

@ -68,7 +68,10 @@ class ImageClassifier(classifier.Classifier):
) -> 'ImageClassifier': ) -> 'ImageClassifier':
"""Creates and trains an image classifier. """Creates and trains an image classifier.
Loads data and trains the model based on data for image classification. Loads data and trains the model based on data for image classification. If a
checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
directory, the training process will load the weight from the checkpoint
file for continual training.
Args: Args:
train_data: Training data. train_data: Training data.

View File

@ -13,10 +13,13 @@
# limitations under the License. # limitations under the License.
import filecmp import filecmp
import io
import os import os
import tempfile
from unittest import mock from unittest import mock as unittest_mock
from absl.testing import parameterized from absl.testing import parameterized
import mock
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -63,14 +66,20 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
options=image_classifier.ImageClassifierOptions( options=image_classifier.ImageClassifierOptions(
supported_model=image_classifier.SupportedModels.MOBILENET_V2, supported_model=image_classifier.SupportedModels.MOBILENET_V2,
hparams=image_classifier.HParams( hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))), epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
dict( dict(
testcase_name='efficientnet_lite0', testcase_name='efficientnet_lite0',
options=image_classifier.ImageClassifierOptions( options=image_classifier.ImageClassifierOptions(
supported_model=( supported_model=(
image_classifier.SupportedModels.EFFICIENTNET_LITE0), image_classifier.SupportedModels.EFFICIENTNET_LITE0),
hparams=image_classifier.HParams( hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))), epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
dict( dict(
testcase_name='efficientnet_lite0_change_dropout_rate', testcase_name='efficientnet_lite0_change_dropout_rate',
options=image_classifier.ImageClassifierOptions( options=image_classifier.ImageClassifierOptions(
@ -78,21 +87,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
image_classifier.SupportedModels.EFFICIENTNET_LITE0), image_classifier.SupportedModels.EFFICIENTNET_LITE0),
model_options=image_classifier.ModelOptions(dropout_rate=0.1), model_options=image_classifier.ModelOptions(dropout_rate=0.1),
hparams=image_classifier.HParams( hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))), epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
dict( dict(
testcase_name='efficientnet_lite2', testcase_name='efficientnet_lite2',
options=image_classifier.ImageClassifierOptions( options=image_classifier.ImageClassifierOptions(
supported_model=( supported_model=(
image_classifier.SupportedModels.EFFICIENTNET_LITE2), image_classifier.SupportedModels.EFFICIENTNET_LITE2),
hparams=image_classifier.HParams( hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))), epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
dict( dict(
testcase_name='efficientnet_lite4', testcase_name='efficientnet_lite4',
options=image_classifier.ImageClassifierOptions( options=image_classifier.ImageClassifierOptions(
supported_model=( supported_model=(
image_classifier.SupportedModels.EFFICIENTNET_LITE4), image_classifier.SupportedModels.EFFICIENTNET_LITE4),
hparams=image_classifier.HParams( hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))), epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
) )
def test_create_and_train_model( def test_create_and_train_model(
self, options: image_classifier.ImageClassifierOptions): self, options: image_classifier.ImageClassifierOptions):
@ -117,16 +135,35 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
def test_continual_training_by_loading_checkpoint(self):
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
options = image_classifier.ImageClassifierOptions(
supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
hparams=image_classifier.HParams(
epochs=5, batch_size=1, shuffle=True))
model = image_classifier.ImageClassifier.create(
train_data=self._train_data,
validation_data=self._test_data,
options=options)
model = image_classifier.ImageClassifier.create(
train_data=self._train_data,
validation_data=self._test_data,
options=options)
self._test_accuracy(model)
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
def _test_accuracy(self, model, threshold=0.0): def _test_accuracy(self, model, threshold=0.0):
_, accuracy = model.evaluate(self._test_data) _, accuracy = model.evaluate(self._test_data)
self.assertGreaterEqual(accuracy, threshold) self.assertGreaterEqual(accuracy, threshold)
@mock.patch.object( @unittest_mock.patch.object(
image_classifier.hyperparameters, image_classifier.hyperparameters,
'HParams', 'HParams',
autospec=True, autospec=True,
return_value=image_classifier.HParams(epochs=1)) return_value=image_classifier.HParams(epochs=1))
@mock.patch.object( @unittest_mock.patch.object(
image_classifier.model_options, image_classifier.model_options,
'ImageClassifierModelOptions', 'ImageClassifierModelOptions',
autospec=True, autospec=True,

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Library to train model.""" """Library to train model."""
import os
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
@ -78,11 +79,24 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
loss = tf.keras.losses.CategoricalCrossentropy( loss = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=hparams.label_smoothing) label_smoothing=hparams.label_smoothing)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir)
summary_dir = os.path.join(hparams.export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 5 epochs.
checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
save_weights_only=True,
period=5)
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
if latest_checkpoint:
print(f'Resuming from {latest_checkpoint}')
model.load_weights(latest_checkpoint)
# Train the model. # Train the model.
return model.fit( return model.fit(
x=train_ds, x=train_ds,
epochs=hparams.epochs, epochs=hparams.epochs,
validation_data=validation_ds, validation_data=validation_ds,
callbacks=callbacks) callbacks=[summary_callback, checkpoint_callback])