Internal Model Maker change.
PiperOrigin-RevId: 500758488
This commit is contained in:
parent
73f4636292
commit
d40fa6b16d
|
@ -48,11 +48,12 @@ class Classifier(custom_model.CustomModel):
|
|||
self._hparams: hp.BaseHParams = None
|
||||
self._history: tf.keras.callbacks.History = None
|
||||
|
||||
# TODO: Integrate this into all Model Maker tasks.
|
||||
# TODO: Integrate this into GestureRecognizer.
|
||||
def _train_model(self,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset,
|
||||
preprocessor: Optional[Callable[..., bool]] = None):
|
||||
preprocessor: Optional[Callable[..., bool]] = None,
|
||||
checkpoint_path: Optional[str] = None):
|
||||
"""Trains the classifier model.
|
||||
|
||||
Compiles and fits the tf.keras `_model` and records the `_history`.
|
||||
|
@ -62,6 +63,9 @@ class Classifier(custom_model.CustomModel):
|
|||
validation_data: Validation data.
|
||||
preprocessor: An optional data preprocessor that can be used when
|
||||
generating a tf.data.Dataset.
|
||||
checkpoint_path: An optional directory for the checkpoint file to support
|
||||
continual training. If provided, loads model weights from the latest
|
||||
checkpoint in the directory.
|
||||
"""
|
||||
tf.compat.v1.logging.info('Training the models...')
|
||||
if len(train_data) < self._hparams.batch_size:
|
||||
|
@ -88,6 +92,14 @@ class Classifier(custom_model.CustomModel):
|
|||
optimizer=self._optimizer,
|
||||
loss=self._loss_function,
|
||||
metrics=[self._metric_function])
|
||||
|
||||
latest_checkpoint = (
|
||||
tf.train.latest_checkpoint(checkpoint_path)
|
||||
if checkpoint_path else None)
|
||||
if latest_checkpoint:
|
||||
print(f'Resuming from {latest_checkpoint}')
|
||||
self._model.load_weights(latest_checkpoint)
|
||||
|
||||
self._history = self._model.fit(
|
||||
x=train_dataset,
|
||||
epochs=self._hparams.epochs,
|
||||
|
|
|
@ -42,7 +42,9 @@ def get_default_callbacks(
|
|||
|
||||
checkpoint_path = os.path.join(export_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
checkpoint_path, save_weights_only=True)
|
||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||
save_weights_only=True,
|
||||
period=5)
|
||||
return [summary_callback, checkpoint_callback]
|
||||
|
||||
|
||||
|
|
|
@ -87,15 +87,6 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "train_image_classifier_lib",
|
||||
srcs = ["train_image_classifier_lib.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_classifier",
|
||||
srcs = ["image_classifier.py"],
|
||||
|
@ -104,7 +95,6 @@ py_library(
|
|||
":image_classifier_options",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":train_image_classifier_lib",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
|
|
|
@ -35,4 +35,3 @@ del image_classifier
|
|||
del image_classifier_options
|
||||
del model_options
|
||||
del model_spec
|
||||
del train_image_classifier_lib # pylint: disable=undefined-variable
|
||||
|
|
|
@ -28,7 +28,6 @@ from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
|||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
||||
|
@ -57,6 +56,10 @@ class ImageClassifier(classifier.Classifier):
|
|||
mean_rgb=self._model_spec.mean_rgb,
|
||||
stddev_rgb=self._model_spec.stddev_rgb,
|
||||
use_augmentation=hparams.do_data_augmentation)
|
||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||
self._loss_function = tf.keras.losses.CategoricalCrossentropy(
|
||||
label_smoothing=self._hparams.label_smoothing)
|
||||
self._metric_function = 'accuracy'
|
||||
self._history = None # Training history returned from `keras_model.fit`.
|
||||
|
||||
@classmethod
|
||||
|
@ -66,7 +69,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
validation_data: classification_ds.ClassificationDataset,
|
||||
options: image_classifier_options.ImageClassifierOptions,
|
||||
) -> 'ImageClassifier':
|
||||
"""Creates and trains an image classifier.
|
||||
"""Creates and trains an ImageClassifier.
|
||||
|
||||
Loads data and trains the model based on data for image classification. If a
|
||||
checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
|
||||
|
@ -93,58 +96,29 @@ class ImageClassifier(classifier.Classifier):
|
|||
label_names=train_data.label_names,
|
||||
hparams=options.hparams,
|
||||
model_options=options.model_options)
|
||||
|
||||
image_classifier._create_model()
|
||||
|
||||
tf.compat.v1.logging.info('Training the models...')
|
||||
image_classifier._train(
|
||||
train_data=train_data, validation_data=validation_data)
|
||||
|
||||
image_classifier._create_and_train_model(train_data, validation_data)
|
||||
return image_classifier
|
||||
|
||||
# TODO: Migrate to the shared training library of Model Maker.
|
||||
def _train(self, train_data: classification_ds.ClassificationDataset,
|
||||
def _create_and_train_model(
|
||||
self, train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset):
|
||||
"""Trains the model with input train_data.
|
||||
|
||||
The training results are recorded by a self._history object returned by
|
||||
tf.keras.Model.fit().
|
||||
"""Creates and trains the model and optimizer.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
"""
|
||||
|
||||
tf.compat.v1.logging.info('Training the models...')
|
||||
hparams = self._hparams
|
||||
if len(train_data) < hparams.batch_size:
|
||||
raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
|
||||
'than batch_size (%d). To solve this problem, set '
|
||||
'the batch_size smaller or increase the size of the '
|
||||
'train_data.' % (len(train_data), hparams.batch_size))
|
||||
|
||||
train_dataset = train_data.gen_tf_dataset(
|
||||
batch_size=hparams.batch_size,
|
||||
is_training=True,
|
||||
shuffle=self._shuffle,
|
||||
preprocess=self._preprocess)
|
||||
hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=hparams.steps_per_epoch,
|
||||
batch_size=hparams.batch_size,
|
||||
self._create_model()
|
||||
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=self._hparams.steps_per_epoch,
|
||||
batch_size=self._hparams.batch_size,
|
||||
train_data=train_data)
|
||||
train_dataset = train_dataset.take(count=hparams.steps_per_epoch)
|
||||
|
||||
validation_dataset = validation_data.gen_tf_dataset(
|
||||
batch_size=hparams.batch_size,
|
||||
is_training=False,
|
||||
preprocess=self._preprocess)
|
||||
|
||||
# Train the model.
|
||||
self._history = train_image_classifier_lib.train_model(
|
||||
model=self._model,
|
||||
hparams=hparams,
|
||||
train_ds=train_dataset,
|
||||
validation_ds=validation_dataset)
|
||||
self._optimizer = self._create_optimizer()
|
||||
self._train_model(
|
||||
train_data=train_data,
|
||||
validation_data=validation_data,
|
||||
preprocessor=self._preprocess,
|
||||
checkpoint_path=os.path.join(self._hparams.export_dir, 'checkpoint'))
|
||||
|
||||
def _create_model(self):
|
||||
"""Creates the classifier model from TFHub pretrained models."""
|
||||
|
@ -198,3 +172,33 @@ class ImageClassifier(classifier.Classifier):
|
|||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||
with open(metadata_file, 'w') as f:
|
||||
f.write(metadata_json)
|
||||
|
||||
def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
|
||||
"""Creates an optimizer with learning rate schedule.
|
||||
|
||||
Uses Keras CosineDecay schedule for the learning rate by default.
|
||||
|
||||
Returns:
|
||||
A tf.keras.optimizers.Optimizer for model training.
|
||||
"""
|
||||
# Learning rate is linear to batch size.
|
||||
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
|
||||
|
||||
# Get decay steps.
|
||||
total_training_steps = self._hparams.steps_per_epoch * self._hparams.epochs
|
||||
default_decay_steps = (
|
||||
self._hparams.decay_samples // self._hparams.batch_size)
|
||||
decay_steps = max(total_training_steps, default_decay_steps)
|
||||
|
||||
learning_rate_fn = tf.keras.experimental.CosineDecay(
|
||||
initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
|
||||
warmup_steps = self._hparams.warmup_epochs * self._hparams.steps_per_epoch
|
||||
if warmup_steps:
|
||||
learning_rate_fn = model_util.WarmUp(
|
||||
initial_learning_rate=init_lr,
|
||||
decay_schedule_fn=learning_rate_fn,
|
||||
warmup_steps=warmup_steps)
|
||||
optimizer = tf.keras.optimizers.RMSprop(
|
||||
learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
|
||||
|
||||
return optimizer
|
||||
|
|
|
@ -1,102 +0,0 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Library to train model."""
|
||||
|
||||
import os
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||
|
||||
|
||||
def _create_optimizer(init_lr: float, decay_steps: int,
|
||||
warmup_steps: int) -> tf.keras.optimizers.Optimizer:
|
||||
"""Creates an optimizer with learning rate schedule.
|
||||
|
||||
Uses Keras CosineDecay schedule for the learning rate by default.
|
||||
|
||||
Args:
|
||||
init_lr: Initial learning rate.
|
||||
decay_steps: Number of steps to decay over.
|
||||
warmup_steps: Number of steps to do warmup for.
|
||||
|
||||
Returns:
|
||||
A tf.keras.optimizers.Optimizer for model training.
|
||||
"""
|
||||
learning_rate_fn = tf.keras.experimental.CosineDecay(
|
||||
initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
|
||||
if warmup_steps:
|
||||
learning_rate_fn = model_util.WarmUp(
|
||||
initial_learning_rate=init_lr,
|
||||
decay_schedule_fn=learning_rate_fn,
|
||||
warmup_steps=warmup_steps)
|
||||
optimizer = tf.keras.optimizers.RMSprop(
|
||||
learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
||||
train_ds: tf.data.Dataset,
|
||||
validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
|
||||
"""Trains model with the input data and hyperparameters.
|
||||
|
||||
Args:
|
||||
model: Input tf.keras.Model.
|
||||
hparams: Hyperparameters for training image classifier.
|
||||
train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
|
||||
validation_ds: tf.data.Dataset, validation data to be fed in
|
||||
tf.keras.Model.fit().
|
||||
|
||||
Returns:
|
||||
The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
|
||||
"""
|
||||
|
||||
# Learning rate is linear to batch size.
|
||||
learning_rate = hparams.learning_rate * hparams.batch_size / 256
|
||||
|
||||
# Get decay steps.
|
||||
# NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
|
||||
total_training_steps = hparams.steps_per_epoch * hparams.epochs
|
||||
default_decay_steps = hparams.decay_samples // hparams.batch_size
|
||||
decay_steps = max(total_training_steps, default_decay_steps)
|
||||
|
||||
warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch
|
||||
optimizer = _create_optimizer(
|
||||
init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps)
|
||||
|
||||
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||
label_smoothing=hparams.label_smoothing)
|
||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||
|
||||
summary_dir = os.path.join(hparams.export_dir, 'summaries')
|
||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||
# Save checkpoint every 5 epochs.
|
||||
checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||
save_weights_only=True,
|
||||
period=5)
|
||||
|
||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||
if latest_checkpoint:
|
||||
print(f'Resuming from {latest_checkpoint}')
|
||||
model.load_weights(latest_checkpoint)
|
||||
|
||||
# Train the model.
|
||||
return model.fit(
|
||||
x=train_ds,
|
||||
epochs=hparams.epochs,
|
||||
validation_data=validation_ds,
|
||||
callbacks=[summary_callback, checkpoint_callback])
|
Loading…
Reference in New Issue
Block a user