Support continual training image classifier from saved checkpoint files.
PiperOrigin-RevId: 487057612
This commit is contained in:
parent
b3d19fa1af
commit
0917e8cb8e
|
@ -68,7 +68,10 @@ class ImageClassifier(classifier.Classifier):
|
|||
) -> 'ImageClassifier':
|
||||
"""Creates and trains an image classifier.
|
||||
|
||||
Loads data and trains the model based on data for image classification.
|
||||
Loads data and trains the model based on data for image classification. If a
|
||||
checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
|
||||
directory, the training process will load the weight from the checkpoint
|
||||
file for continual training.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
|
|
|
@ -13,10 +13,13 @@
|
|||
# limitations under the License.
|
||||
|
||||
import filecmp
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from unittest import mock
|
||||
from unittest import mock as unittest_mock
|
||||
from absl.testing import parameterized
|
||||
import mock
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -63,14 +66,20 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=image_classifier.SupportedModels.MOBILENET_V2,
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite0',
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite0_change_dropout_rate',
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
|
@ -78,21 +87,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||
model_options=image_classifier.ModelOptions(dropout_rate=0.1),
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite2',
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE2),
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite4',
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE4),
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
)
|
||||
def test_create_and_train_model(
|
||||
self, options: image_classifier.ImageClassifierOptions):
|
||||
|
@ -117,16 +135,35 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
||||
|
||||
def test_continual_training_by_loading_checkpoint(self):
|
||||
mock_stdout = io.StringIO()
|
||||
with mock.patch('sys.stdout', mock_stdout):
|
||||
options = image_classifier.ImageClassifierOptions(
|
||||
supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=5, batch_size=1, shuffle=True))
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
self._test_accuracy(model)
|
||||
|
||||
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
|
||||
|
||||
def _test_accuracy(self, model, threshold=0.0):
|
||||
_, accuracy = model.evaluate(self._test_data)
|
||||
self.assertGreaterEqual(accuracy, threshold)
|
||||
|
||||
@mock.patch.object(
|
||||
@unittest_mock.patch.object(
|
||||
image_classifier.hyperparameters,
|
||||
'HParams',
|
||||
autospec=True,
|
||||
return_value=image_classifier.HParams(epochs=1))
|
||||
@mock.patch.object(
|
||||
@unittest_mock.patch.object(
|
||||
image_classifier.model_options,
|
||||
'ImageClassifierModelOptions',
|
||||
autospec=True,
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""Library to train model."""
|
||||
|
||||
import os
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
|
@ -78,11 +79,24 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
|||
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||
label_smoothing=hparams.label_smoothing)
|
||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||
callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir)
|
||||
|
||||
summary_dir = os.path.join(hparams.export_dir, 'summaries')
|
||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||
# Save checkpoint every 5 epochs.
|
||||
checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||
save_weights_only=True,
|
||||
period=5)
|
||||
|
||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||
if latest_checkpoint:
|
||||
print(f'Resuming from {latest_checkpoint}')
|
||||
model.load_weights(latest_checkpoint)
|
||||
|
||||
# Train the model.
|
||||
return model.fit(
|
||||
x=train_ds,
|
||||
epochs=hparams.epochs,
|
||||
validation_data=validation_ds,
|
||||
callbacks=callbacks)
|
||||
callbacks=[summary_callback, checkpoint_callback])
|
||||
|
|
Loading…
Reference in New Issue
Block a user