Support continual training image classifier from saved checkpoint files.
PiperOrigin-RevId: 487057612
This commit is contained in:
parent
b3d19fa1af
commit
0917e8cb8e
|
@ -68,7 +68,10 @@ class ImageClassifier(classifier.Classifier):
|
||||||
) -> 'ImageClassifier':
|
) -> 'ImageClassifier':
|
||||||
"""Creates and trains an image classifier.
|
"""Creates and trains an image classifier.
|
||||||
|
|
||||||
Loads data and trains the model based on data for image classification.
|
Loads data and trains the model based on data for image classification. If a
|
||||||
|
checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
|
||||||
|
directory, the training process will load the weight from the checkpoint
|
||||||
|
file for continual training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
train_data: Training data.
|
train_data: Training data.
|
||||||
|
|
|
@ -13,10 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import filecmp
|
import filecmp
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from unittest import mock
|
from unittest import mock as unittest_mock
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
import mock
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -63,14 +66,20 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
options=image_classifier.ImageClassifierOptions(
|
options=image_classifier.ImageClassifierOptions(
|
||||||
supported_model=image_classifier.SupportedModels.MOBILENET_V2,
|
supported_model=image_classifier.SupportedModels.MOBILENET_V2,
|
||||||
hparams=image_classifier.HParams(
|
hparams=image_classifier.HParams(
|
||||||
epochs=1, batch_size=1, shuffle=True))),
|
epochs=1,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=True,
|
||||||
|
export_dir=tempfile.mkdtemp()))),
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite0',
|
testcase_name='efficientnet_lite0',
|
||||||
options=image_classifier.ImageClassifierOptions(
|
options=image_classifier.ImageClassifierOptions(
|
||||||
supported_model=(
|
supported_model=(
|
||||||
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||||
hparams=image_classifier.HParams(
|
hparams=image_classifier.HParams(
|
||||||
epochs=1, batch_size=1, shuffle=True))),
|
epochs=1,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=True,
|
||||||
|
export_dir=tempfile.mkdtemp()))),
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite0_change_dropout_rate',
|
testcase_name='efficientnet_lite0_change_dropout_rate',
|
||||||
options=image_classifier.ImageClassifierOptions(
|
options=image_classifier.ImageClassifierOptions(
|
||||||
|
@ -78,21 +87,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||||
model_options=image_classifier.ModelOptions(dropout_rate=0.1),
|
model_options=image_classifier.ModelOptions(dropout_rate=0.1),
|
||||||
hparams=image_classifier.HParams(
|
hparams=image_classifier.HParams(
|
||||||
epochs=1, batch_size=1, shuffle=True))),
|
epochs=1,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=True,
|
||||||
|
export_dir=tempfile.mkdtemp()))),
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite2',
|
testcase_name='efficientnet_lite2',
|
||||||
options=image_classifier.ImageClassifierOptions(
|
options=image_classifier.ImageClassifierOptions(
|
||||||
supported_model=(
|
supported_model=(
|
||||||
image_classifier.SupportedModels.EFFICIENTNET_LITE2),
|
image_classifier.SupportedModels.EFFICIENTNET_LITE2),
|
||||||
hparams=image_classifier.HParams(
|
hparams=image_classifier.HParams(
|
||||||
epochs=1, batch_size=1, shuffle=True))),
|
epochs=1,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=True,
|
||||||
|
export_dir=tempfile.mkdtemp()))),
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite4',
|
testcase_name='efficientnet_lite4',
|
||||||
options=image_classifier.ImageClassifierOptions(
|
options=image_classifier.ImageClassifierOptions(
|
||||||
supported_model=(
|
supported_model=(
|
||||||
image_classifier.SupportedModels.EFFICIENTNET_LITE4),
|
image_classifier.SupportedModels.EFFICIENTNET_LITE4),
|
||||||
hparams=image_classifier.HParams(
|
hparams=image_classifier.HParams(
|
||||||
epochs=1, batch_size=1, shuffle=True))),
|
epochs=1,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=True,
|
||||||
|
export_dir=tempfile.mkdtemp()))),
|
||||||
)
|
)
|
||||||
def test_create_and_train_model(
|
def test_create_and_train_model(
|
||||||
self, options: image_classifier.ImageClassifierOptions):
|
self, options: image_classifier.ImageClassifierOptions):
|
||||||
|
@ -117,16 +135,35 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
||||||
|
|
||||||
|
def test_continual_training_by_loading_checkpoint(self):
|
||||||
|
mock_stdout = io.StringIO()
|
||||||
|
with mock.patch('sys.stdout', mock_stdout):
|
||||||
|
options = image_classifier.ImageClassifierOptions(
|
||||||
|
supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
epochs=5, batch_size=1, shuffle=True))
|
||||||
|
model = image_classifier.ImageClassifier.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._test_data,
|
||||||
|
options=options)
|
||||||
|
model = image_classifier.ImageClassifier.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._test_data,
|
||||||
|
options=options)
|
||||||
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
|
||||||
|
|
||||||
def _test_accuracy(self, model, threshold=0.0):
|
def _test_accuracy(self, model, threshold=0.0):
|
||||||
_, accuracy = model.evaluate(self._test_data)
|
_, accuracy = model.evaluate(self._test_data)
|
||||||
self.assertGreaterEqual(accuracy, threshold)
|
self.assertGreaterEqual(accuracy, threshold)
|
||||||
|
|
||||||
@mock.patch.object(
|
@unittest_mock.patch.object(
|
||||||
image_classifier.hyperparameters,
|
image_classifier.hyperparameters,
|
||||||
'HParams',
|
'HParams',
|
||||||
autospec=True,
|
autospec=True,
|
||||||
return_value=image_classifier.HParams(epochs=1))
|
return_value=image_classifier.HParams(epochs=1))
|
||||||
@mock.patch.object(
|
@unittest_mock.patch.object(
|
||||||
image_classifier.model_options,
|
image_classifier.model_options,
|
||||||
'ImageClassifierModelOptions',
|
'ImageClassifierModelOptions',
|
||||||
autospec=True,
|
autospec=True,
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Library to train model."""
|
"""Library to train model."""
|
||||||
|
|
||||||
|
import os
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.utils import model_util
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
@ -78,11 +79,24 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
||||||
loss = tf.keras.losses.CategoricalCrossentropy(
|
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||||
label_smoothing=hparams.label_smoothing)
|
label_smoothing=hparams.label_smoothing)
|
||||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||||
callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir)
|
|
||||||
|
summary_dir = os.path.join(hparams.export_dir, 'summaries')
|
||||||
|
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||||
|
# Save checkpoint every 5 epochs.
|
||||||
|
checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
|
||||||
|
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||||
|
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||||
|
save_weights_only=True,
|
||||||
|
period=5)
|
||||||
|
|
||||||
|
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||||
|
if latest_checkpoint:
|
||||||
|
print(f'Resuming from {latest_checkpoint}')
|
||||||
|
model.load_weights(latest_checkpoint)
|
||||||
|
|
||||||
# Train the model.
|
# Train the model.
|
||||||
return model.fit(
|
return model.fit(
|
||||||
x=train_ds,
|
x=train_ds,
|
||||||
epochs=hparams.epochs,
|
epochs=hparams.epochs,
|
||||||
validation_data=validation_ds,
|
validation_data=validation_ds,
|
||||||
callbacks=callbacks)
|
callbacks=[summary_callback, checkpoint_callback])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user