Internal Model Maker change.

PiperOrigin-RevId: 500758488
MediaPipe Team 2023-01-09 11:02:48 -08:00 committed by Copybara-Service
parent 73f4636292
commit d40fa6b16d
6 changed files with 67 additions and 162 deletions

View File

@@ -48,11 +48,12 @@ class Classifier(custom_model.CustomModel):
self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None
# TODO: Integrate this into all Model Maker tasks.
# TODO: Integrate this into GestureRecognizer.
def _train_model(self,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None):
preprocessor: Optional[Callable[..., bool]] = None,
checkpoint_path: Optional[str] = None):
"""Trains the classifier model.
Compiles and fits the tf.keras `_model` and records the `_history`.
@@ -62,6 +63,9 @@ class Classifier(custom_model.CustomModel):
validation_data: Validation data.
preprocessor: An optional data preprocessor that can be used when
generating a tf.data.Dataset.
checkpoint_path: An optional directory for the checkpoint file to support
continual training. If provided, loads model weights from the latest
checkpoint in the directory.
"""
tf.compat.v1.logging.info('Training the models...')
if len(train_data) < self._hparams.batch_size:
@@ -88,6 +92,14 @@ class Classifier(custom_model.CustomModel):
optimizer=self._optimizer,
loss=self._loss_function,
metrics=[self._metric_function])
latest_checkpoint = (
tf.train.latest_checkpoint(checkpoint_path)
if checkpoint_path else None)
if latest_checkpoint:
print(f'Resuming from {latest_checkpoint}')
self._model.load_weights(latest_checkpoint)
self._history = self._model.fit(
x=train_dataset,
epochs=self._hparams.epochs,
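The hunk above threads an optional checkpoint_path through Classifier._train_model so an interrupted run can pick up from its latest saved weights. A minimal standalone sketch of that resume pattern, assuming a checkpoint directory written by ModelCheckpoint(save_weights_only=True); the helper name resume_if_possible is made up for illustration:

from typing import Optional

import tensorflow as tf


def resume_if_possible(model: tf.keras.Model,
                       checkpoint_path: Optional[str] = None) -> tf.keras.Model:
  """Restores the newest weights found under `checkpoint_path`, if any."""
  latest = (
      tf.train.latest_checkpoint(checkpoint_path) if checkpoint_path else None)
  if latest:
    # Weights-only checkpoints can be restored onto a freshly built model with
    # the same architecture; training then continues from those weights.
    print(f'Resuming from {latest}')
    model.load_weights(latest)
  return model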

View File

@@ -42,7 +42,9 @@ def get_default_callbacks(
checkpoint_path = os.path.join(export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True)
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
save_weights_only=True,
period=5)
return [summary_callback, checkpoint_callback]
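The default ModelCheckpoint callback now writes epoch-numbered, weights-only checkpoints every five epochs under <export_dir>/checkpoint rather than overwriting a single file. A hedged sketch of the resulting callback list (the function name is illustrative, summary_callback mirrors the TensorBoard callback created earlier in get_default_callbacks, and period= is the legacy keyword that newer tf.keras releases supersede with save_freq):

import os

import tensorflow as tf


def default_callbacks_sketch(export_dir: str):
  # TensorBoard summaries go under <export_dir>/summaries.
  summary_callback = tf.keras.callbacks.TensorBoard(
      os.path.join(export_dir, 'summaries'))
  # Weights-only checkpoints named model-0005, model-0010, ... saved every
  # 5 epochs; tf.train.latest_checkpoint can later pick up the newest one.
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      os.path.join(export_dir, 'checkpoint', 'model-{epoch:04d}'),
      save_weights_only=True,
      period=5)
  return [summary_callback, checkpoint_callback]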

View File

@@ -87,15 +87,6 @@ py_library(
],
)
py_library(
name = "train_image_classifier_lib",
srcs = ["train_image_classifier_lib.py"],
deps = [
":hyperparameters",
"//mediapipe/model_maker/python/core/utils:model_util",
],
)
py_library(
name = "image_classifier",
srcs = ["image_classifier.py"],
@@ -104,7 +95,6 @@ py_library(
":image_classifier_options",
":model_options",
":model_spec",
":train_image_classifier_lib",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:model_util",

View File

@@ -35,4 +35,3 @@ del image_classifier
del image_classifier_options
del model_options
del model_spec
del train_image_classifier_lib # pylint: disable=undefined-variable

View File

@@ -28,7 +28,6 @@ from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
@@ -57,6 +56,10 @@ class ImageClassifier(classifier.Classifier):
mean_rgb=self._model_spec.mean_rgb,
stddev_rgb=self._model_spec.stddev_rgb,
use_augmentation=hparams.do_data_augmentation)
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._loss_function = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=self._hparams.label_smoothing)
self._metric_function = 'accuracy'
self._history = None # Training history returned from `keras_model.fit`.
@classmethod
@@ -66,7 +69,7 @@ class ImageClassifier(classifier.Classifier):
validation_data: classification_ds.ClassificationDataset,
options: image_classifier_options.ImageClassifierOptions,
) -> 'ImageClassifier':
"""Creates and trains an image classifier.
"""Creates and trains an ImageClassifier.
Loads data and trains the model based on data for image classification. If a
checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
@@ -93,58 +96,29 @@ class ImageClassifier(classifier.Classifier):
label_names=train_data.label_names,
hparams=options.hparams,
model_options=options.model_options)
image_classifier._create_model()
tf.compat.v1.logging.info('Training the models...')
image_classifier._train(
train_data=train_data, validation_data=validation_data)
image_classifier._create_and_train_model(train_data, validation_data)
return image_classifier
# TODO: Migrate to the shared training library of Model Maker.
def _train(self, train_data: classification_ds.ClassificationDataset,
def _create_and_train_model(
self, train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset):
"""Trains the model with input train_data.
The training results are recorded by a self._history object returned by
tf.keras.Model.fit().
"""Creates and trains the model and optimizer.
Args:
train_data: Training data.
validation_data: Validation data.
"""
tf.compat.v1.logging.info('Training the models...')
hparams = self._hparams
if len(train_data) < hparams.batch_size:
raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
'than batch_size (%d). To solve this problem, set '
'the batch_size smaller or increase the size of the '
'train_data.' % (len(train_data), hparams.batch_size))
train_dataset = train_data.gen_tf_dataset(
batch_size=hparams.batch_size,
is_training=True,
shuffle=self._shuffle,
preprocess=self._preprocess)
hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=hparams.steps_per_epoch,
batch_size=hparams.batch_size,
self._create_model()
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=self._hparams.steps_per_epoch,
batch_size=self._hparams.batch_size,
train_data=train_data)
train_dataset = train_dataset.take(count=hparams.steps_per_epoch)
validation_dataset = validation_data.gen_tf_dataset(
batch_size=hparams.batch_size,
is_training=False,
preprocess=self._preprocess)
# Train the model.
self._history = train_image_classifier_lib.train_model(
model=self._model,
hparams=hparams,
train_ds=train_dataset,
validation_ds=validation_dataset)
self._optimizer = self._create_optimizer()
self._train_model(
train_data=train_data,
validation_data=validation_data,
preprocessor=self._preprocess,
checkpoint_path=os.path.join(self._hparams.export_dir, 'checkpoint'))
def _create_model(self):
"""Creates the classifier model from TFHub pretrained models."""
@@ -198,3 +172,33 @@ class ImageClassifier(classifier.Classifier):
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with open(metadata_file, 'w') as f:
f.write(metadata_json)
def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Creates an optimizer with learning rate schedule.
Uses Keras CosineDecay schedule for the learning rate by default.
Returns:
A tf.keras.optimizers.Optimizer for model training.
"""
# Learning rate is linear to batch size.
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
# Get decay steps.
total_training_steps = self._hparams.steps_per_epoch * self._hparams.epochs
default_decay_steps = (
self._hparams.decay_samples // self._hparams.batch_size)
decay_steps = max(total_training_steps, default_decay_steps)
learning_rate_fn = tf.keras.experimental.CosineDecay(
initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
warmup_steps = self._hparams.warmup_epochs * self._hparams.steps_per_epoch
if warmup_steps:
learning_rate_fn = model_util.WarmUp(
initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=warmup_steps)
optimizer = tf.keras.optimizers.RMSprop(
learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
return optimizer
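The new _create_optimizer keeps the schedule that previously lived in train_image_classifier_lib: the base learning rate scales linearly with batch size, cosine decay runs over the larger of the total training steps and decay_samples // batch_size, and an optional warmup wraps the schedule. A sketch of that arithmetic with illustrative hyperparameter values, using only stock Keras pieces (model_util.WarmUp is Model Maker's own wrapper and is left out here):

import tensorflow as tf

# Illustrative values only; the real ones come from the task's hparams.
learning_rate = 0.001
batch_size = 32
steps_per_epoch = 100
epochs = 10
decay_samples = 2_560_000  # hypothetical default

init_lr = learning_rate * batch_size / 256          # LR is linear in batch size
total_training_steps = steps_per_epoch * epochs     # 1_000
default_decay_steps = decay_samples // batch_size   # 80_000
decay_steps = max(total_training_steps, default_decay_steps)

schedule = tf.keras.experimental.CosineDecay(
    initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
optimizer = tf.keras.optimizers.RMSprop(
    learning_rate=schedule, rho=0.9, momentum=0.9, epsilon=0.001)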

View File

@@ -1,102 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to train model."""
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
def _create_optimizer(init_lr: float, decay_steps: int,
warmup_steps: int) -> tf.keras.optimizers.Optimizer:
"""Creates an optimizer with learning rate schedule.
Uses Keras CosineDecay schedule for the learning rate by default.
Args:
init_lr: Initial learning rate.
decay_steps: Number of steps to decay over.
warmup_steps: Number of steps to do warmup for.
Returns:
A tf.keras.optimizers.Optimizer for model training.
"""
learning_rate_fn = tf.keras.experimental.CosineDecay(
initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
if warmup_steps:
learning_rate_fn = model_util.WarmUp(
initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=warmup_steps)
optimizer = tf.keras.optimizers.RMSprop(
learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
return optimizer
def train_model(model: tf.keras.Model, hparams: hp.HParams,
train_ds: tf.data.Dataset,
validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
"""Trains model with the input data and hyperparameters.
Args:
model: Input tf.keras.Model.
hparams: Hyperparameters for training image classifier.
train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
validation_ds: tf.data.Dataset, validation data to be fed in
tf.keras.Model.fit().
Returns:
The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
"""
# Learning rate is linear to batch size.
learning_rate = hparams.learning_rate * hparams.batch_size / 256
# Get decay steps.
# NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
total_training_steps = hparams.steps_per_epoch * hparams.epochs
default_decay_steps = hparams.decay_samples // hparams.batch_size
decay_steps = max(total_training_steps, default_decay_steps)
warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch
optimizer = _create_optimizer(
init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps)
loss = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=hparams.label_smoothing)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
summary_dir = os.path.join(hparams.export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 5 epochs.
checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
save_weights_only=True,
period=5)
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
if latest_checkpoint:
print(f'Resuming from {latest_checkpoint}')
model.load_weights(latest_checkpoint)
# Train the model.
return model.fit(
x=train_ds,
epochs=hparams.epochs,
validation_data=validation_ds,
callbacks=[summary_callback, checkpoint_callback])
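With the standalone training library removed, image-classifier training funnels through ImageClassifier.create and the shared Classifier._train_model. A hedged end-to-end usage sketch; the import path and the Dataset.from_folder, SupportedModels, HParams, and ImageClassifierOptions names follow the public Model Maker documentation and may differ between releases:

from mediapipe_model_maker import image_classifier  # import path may vary by release

# Assumed data layout: one sub-folder of images per label.
data = image_classifier.Dataset.from_folder('path/to/labeled_images')
train_data, validation_data = data.split(0.9)

options = image_classifier.ImageClassifierOptions(
    supported_model=image_classifier.SupportedModels.MOBILENET_V2,
    hparams=image_classifier.HParams(export_dir='exported_model', epochs=10))

# create() builds the model and trains it via Classifier._train_model; if
# exported_model/checkpoint/ already holds weights, training resumes from them.
model = image_classifier.ImageClassifier.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options)
model.export_model()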