Internal Model Maker change.

PiperOrigin-RevId: 500758488
MediaPipe Team 2023-01-09 11:02:48 -08:00 committed by Copybara-Service
parent 73f4636292
commit d40fa6b16d
6 changed files with 67 additions and 162 deletions

mediapipe/model_maker/python/core/tasks/classifier.py

@@ -48,11 +48,12 @@ class Classifier(custom_model.CustomModel):
     self._hparams: hp.BaseHParams = None
     self._history: tf.keras.callbacks.History = None
 
-  # TODO: Integrate this into all Model Maker tasks.
+  # TODO: Integrate this into GestureRecognizer.
   def _train_model(self,
                    train_data: classification_ds.ClassificationDataset,
                    validation_data: classification_ds.ClassificationDataset,
-                   preprocessor: Optional[Callable[..., bool]] = None):
+                   preprocessor: Optional[Callable[..., bool]] = None,
+                   checkpoint_path: Optional[str] = None):
     """Trains the classifier model.
 
     Compiles and fits the tf.keras `_model` and records the `_history`.
@@ -62,6 +63,9 @@ class Classifier(custom_model.CustomModel):
       validation_data: Validation data.
       preprocessor: An optional data preprocessor that can be used when
        generating a tf.data.Dataset.
+      checkpoint_path: An optional directory for the checkpoint file to support
+        continual training. If provided, loads model weights from the latest
+        checkpoint in the directory.
     """
     tf.compat.v1.logging.info('Training the models...')
     if len(train_data) < self._hparams.batch_size:
@@ -88,6 +92,14 @@ class Classifier(custom_model.CustomModel):
         optimizer=self._optimizer,
         loss=self._loss_function,
         metrics=[self._metric_function])
+
+    latest_checkpoint = (
+        tf.train.latest_checkpoint(checkpoint_path)
+        if checkpoint_path else None)
+    if latest_checkpoint:
+      print(f'Resuming from {latest_checkpoint}')
+      self._model.load_weights(latest_checkpoint)
+
     self._history = self._model.fit(
         x=train_dataset,
         epochs=self._hparams.epochs,
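
Note: the new `checkpoint_path` plumbing relies on the standard TF 2.x checkpoint-directory convention. Below is a minimal, self-contained sketch of that resume pattern; the toy model and paths are hypothetical, and only `tf.train.latest_checkpoint` plus `load_weights` mirror the code added above.

    import os
    import tensorflow as tf

    checkpoint_dir = '/tmp/model_maker_demo/checkpoint'  # hypothetical path

    # Toy model standing in for self._model.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(2),
    ])

    # Simulate an earlier run that saved epoch-numbered weights in
    # TF-checkpoint format (this also writes the 'checkpoint' marker file
    # that tf.train.latest_checkpoint reads).
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.save_weights(os.path.join(checkpoint_dir, 'model-0005'))

    # The resume logic added to Classifier._train_model: resolve the newest
    # checkpoint prefix in the directory and load it if one exists.
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        print(f'Resuming from {latest_checkpoint}')
        model.load_weights(latest_checkpoint)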

mediapipe/model_maker/python/core/utils/model_util.py

@@ -42,7 +42,9 @@ def get_default_callbacks(
   checkpoint_path = os.path.join(export_dir, 'checkpoint')
   checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      checkpoint_path, save_weights_only=True)
+      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
+      save_weights_only=True,
+      period=5)
   return [summary_callback, checkpoint_callback]
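
Why the filepath changed: the old callback overwrote a single `checkpoint` prefix every epoch, which left nothing for `tf.train.latest_checkpoint` to enumerate. The epoch-numbered template plus `period=5` keeps one weights file per fifth epoch. A runnable sketch of the resulting layout follows; the toy model and paths are hypothetical, and `period` is the older spelling that newer tf.keras deprecates in favor of `save_freq` (which counts batches rather than epochs).

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.Input(shape=(8,)),
                                 tf.keras.layers.Dense(1)])
    model.compile(optimizer='sgd', loss='mse')

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        '/tmp/demo/checkpoint/model-{epoch:04d}',  # epoch-numbered prefix
        save_weights_only=True,
        period=5)  # write on epochs 5, 10, ...

    x, y = np.random.rand(32, 8), np.random.rand(32, 1)
    model.fit(x, y, epochs=10, callbacks=[checkpoint_callback], verbose=0)

    # After this run the directory holds model-0005 and model-0010, and the
    # newest one is what a resumed run would load:
    print(tf.train.latest_checkpoint('/tmp/demo/checkpoint'))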

mediapipe/model_maker/python/vision/image_classifier/BUILD

@@ -87,15 +87,6 @@ py_library(
     ],
 )
 
-py_library(
-    name = "train_image_classifier_lib",
-    srcs = ["train_image_classifier_lib.py"],
-    deps = [
-        ":hyperparameters",
-        "//mediapipe/model_maker/python/core/utils:model_util",
-    ],
-)
-
 py_library(
     name = "image_classifier",
     srcs = ["image_classifier.py"],
@@ -104,7 +95,6 @@ py_library(
         ":image_classifier_options",
         ":model_options",
         ":model_spec",
-        ":train_image_classifier_lib",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
         "//mediapipe/model_maker/python/core/utils:model_util",

mediapipe/model_maker/python/vision/image_classifier/__init__.py

@@ -35,4 +35,3 @@ del image_classifier
 del image_classifier_options
 del model_options
 del model_spec
-del train_image_classifier_lib  # pylint: disable=undefined-variable

mediapipe/model_maker/python/vision/image_classifier/image_classifier.py

@@ -28,7 +28,6 @@ from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
 from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
 from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
 from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
-from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
 from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
 from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
@@ -57,6 +56,10 @@ class ImageClassifier(classifier.Classifier):
         mean_rgb=self._model_spec.mean_rgb,
         stddev_rgb=self._model_spec.stddev_rgb,
         use_augmentation=hparams.do_data_augmentation)
+    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
+    self._loss_function = tf.keras.losses.CategoricalCrossentropy(
+        label_smoothing=self._hparams.label_smoothing)
+    self._metric_function = 'accuracy'
     self._history = None  # Training history returned from `keras_model.fit`.
 
   @classmethod
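
The `label_smoothing` wired in here blends each one-hot target with the uniform distribution: y' = y * (1 - a) + a / K for smoothing factor a and K classes. A quick check of that arithmetic with an assumed a = 0.1 (the actual value comes from hyperparameters.py):

    import tensorflow as tf

    # With a = 0.1 and K = 2, the one-hot target [0, 1] is smoothed to
    # [0.05, 0.95]: 0 * 0.9 + 0.1 / 2 = 0.05 and 1 * 0.9 + 0.1 / 2 = 0.95.
    loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
    y_true = [[0.0, 1.0]]
    y_pred = [[0.05, 0.95]]
    # Cross-entropy against the smoothed target; non-zero even for a
    # prediction that matches the smoothed distribution exactly.
    print(float(loss(y_true, y_pred)))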
@@ -66,7 +69,7 @@ class ImageClassifier(classifier.Classifier):
       validation_data: classification_ds.ClassificationDataset,
       options: image_classifier_options.ImageClassifierOptions,
   ) -> 'ImageClassifier':
-    """Creates and trains an image classifier.
+    """Creates and trains an ImageClassifier.
 
     Loads data and trains the model based on data for image classification. If a
     checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
@@ -93,58 +96,29 @@ class ImageClassifier(classifier.Classifier):
         label_names=train_data.label_names,
         hparams=options.hparams,
         model_options=options.model_options)
-
-    image_classifier._create_model()
-
-    tf.compat.v1.logging.info('Training the models...')
-    image_classifier._train(
-        train_data=train_data, validation_data=validation_data)
-
+    image_classifier._create_and_train_model(train_data, validation_data)
     return image_classifier
 
-  # TODO: Migrate to the shared training library of Model Maker.
-  def _train(self, train_data: classification_ds.ClassificationDataset,
-             validation_data: classification_ds.ClassificationDataset):
-    """Trains the model with input train_data.
-
-    The training results are recorded by a self._history object returned by
-    tf.keras.Model.fit().
+  def _create_and_train_model(
+      self, train_data: classification_ds.ClassificationDataset,
+      validation_data: classification_ds.ClassificationDataset):
+    """Creates and trains the model and optimizer.
 
     Args:
       train_data: Training data.
       validation_data: Validation data.
     """
     self._create_model()
 
-    tf.compat.v1.logging.info('Training the models...')
-    hparams = self._hparams
-    if len(train_data) < hparams.batch_size:
-      raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
-                       'than batch_size (%d). To solve this problem, set '
-                       'the batch_size smaller or increase the size of the '
-                       'train_data.' % (len(train_data), hparams.batch_size))
-
-    train_dataset = train_data.gen_tf_dataset(
-        batch_size=hparams.batch_size,
-        is_training=True,
-        shuffle=self._shuffle,
-        preprocess=self._preprocess)
-    hparams.steps_per_epoch = model_util.get_steps_per_epoch(
-        steps_per_epoch=hparams.steps_per_epoch,
-        batch_size=hparams.batch_size,
+    self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
+        steps_per_epoch=self._hparams.steps_per_epoch,
+        batch_size=self._hparams.batch_size,
         train_data=train_data)
-    train_dataset = train_dataset.take(count=hparams.steps_per_epoch)
-
-    validation_dataset = validation_data.gen_tf_dataset(
-        batch_size=hparams.batch_size,
-        is_training=False,
-        preprocess=self._preprocess)
-    # Train the model.
-    self._history = train_image_classifier_lib.train_model(
-        model=self._model,
-        hparams=hparams,
-        train_ds=train_dataset,
-        validation_ds=validation_dataset)
+    self._optimizer = self._create_optimizer()
+    self._train_model(
+        train_data=train_data,
+        validation_data=validation_data,
+        preprocessor=self._preprocess,
+        checkpoint_path=os.path.join(self._hparams.export_dir, 'checkpoint'))
 
   def _create_model(self):
     """Creates the classifier model from TFHub pretrained models."""
@@ -198,3 +172,33 @@ class ImageClassifier(classifier.Classifier):
     model_util.save_tflite(tflite_model_with_metadata, tflite_file)
     with open(metadata_file, 'w') as f:
       f.write(metadata_json)
+
+  def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
+    """Creates an optimizer with learning rate schedule.
+
+    Uses Keras CosineDecay schedule for the learning rate by default.
+
+    Returns:
+      A tf.keras.optimizers.Optimizer for model training.
+    """
+    # Learning rate is linear to batch size.
+    init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
+    # Get decay steps.
+    total_training_steps = self._hparams.steps_per_epoch * self._hparams.epochs
+    default_decay_steps = (
+        self._hparams.decay_samples // self._hparams.batch_size)
+    decay_steps = max(total_training_steps, default_decay_steps)
+
+    learning_rate_fn = tf.keras.experimental.CosineDecay(
+        initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
+    warmup_steps = self._hparams.warmup_epochs * self._hparams.steps_per_epoch
+    if warmup_steps:
+      learning_rate_fn = model_util.WarmUp(
+          initial_learning_rate=init_lr,
+          decay_schedule_fn=learning_rate_fn,
+          warmup_steps=warmup_steps)
+    optimizer = tf.keras.optimizers.RMSprop(
+        learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
+    return optimizer
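
To make the scaling rule concrete: the initial rate scales linearly with batch size against a reference batch of 256, and the cosine-decay horizon is the larger of the actual training length and a floor derived from `decay_samples`. A worked example under assumed hyperparameter values (the real defaults live in hyperparameters.py and may differ):

    import tensorflow as tf

    # Assumed values, for illustration only.
    learning_rate, batch_size = 0.001, 32
    steps_per_epoch, epochs = 100, 10
    decay_samples, warmup_epochs = 2_560_000, 2

    init_lr = learning_rate * batch_size / 256            # 0.000125
    total_training_steps = steps_per_epoch * epochs       # 1_000
    default_decay_steps = decay_samples // batch_size     # 80_000
    decay_steps = max(total_training_steps, default_decay_steps)  # 80_000
    warmup_steps = warmup_epochs * steps_per_epoch        # 200

    # Cosine decay from init_lr down to alpha * init_lr = 0 over decay_steps;
    # model_util.WarmUp additionally ramps the rate linearly for the first
    # warmup_steps before handing control to this schedule.
    schedule = tf.keras.experimental.CosineDecay(
        initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
    print(float(schedule(0)), float(schedule(decay_steps)))  # 0.000125 0.0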

mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py (deleted)

@@ -1,102 +0,0 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Library to train model."""
-
-import os
-
-import tensorflow as tf
-
-from mediapipe.model_maker.python.core.utils import model_util
-from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
-
-
-def _create_optimizer(init_lr: float, decay_steps: int,
-                      warmup_steps: int) -> tf.keras.optimizers.Optimizer:
-  """Creates an optimizer with learning rate schedule.
-
-  Uses Keras CosineDecay schedule for the learning rate by default.
-
-  Args:
-    init_lr: Initial learning rate.
-    decay_steps: Number of steps to decay over.
-    warmup_steps: Number of steps to do warmup for.
-
-  Returns:
-    A tf.keras.optimizers.Optimizer for model training.
-  """
-  learning_rate_fn = tf.keras.experimental.CosineDecay(
-      initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
-  if warmup_steps:
-    learning_rate_fn = model_util.WarmUp(
-        initial_learning_rate=init_lr,
-        decay_schedule_fn=learning_rate_fn,
-        warmup_steps=warmup_steps)
-  optimizer = tf.keras.optimizers.RMSprop(
-      learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
-  return optimizer
-
-
-def train_model(model: tf.keras.Model, hparams: hp.HParams,
-                train_ds: tf.data.Dataset,
-                validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
-  """Trains model with the input data and hyperparameters.
-
-  Args:
-    model: Input tf.keras.Model.
-    hparams: Hyperparameters for training image classifier.
-    train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
-    validation_ds: tf.data.Dataset, validation data to be fed in
-      tf.keras.Model.fit().
-
-  Returns:
-    The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
-  """
-  # Learning rate is linear to batch size.
-  learning_rate = hparams.learning_rate * hparams.batch_size / 256
-  # Get decay steps.
-  # NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
-  total_training_steps = hparams.steps_per_epoch * hparams.epochs
-  default_decay_steps = hparams.decay_samples // hparams.batch_size
-  decay_steps = max(total_training_steps, default_decay_steps)
-
-  warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch
-  optimizer = _create_optimizer(
-      init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps)
-
-  loss = tf.keras.losses.CategoricalCrossentropy(
-      label_smoothing=hparams.label_smoothing)
-  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
-
-  summary_dir = os.path.join(hparams.export_dir, 'summaries')
-  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
-  # Save checkpoint every 5 epochs.
-  checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
-  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
-      save_weights_only=True,
-      period=5)
-
-  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
-  if latest_checkpoint:
-    print(f'Resuming from {latest_checkpoint}')
-    model.load_weights(latest_checkpoint)
-
-  # Train the model.
-  return model.fit(
-      x=train_ds,
-      epochs=hparams.epochs,
-      validation_data=validation_ds,
-      callbacks=[summary_callback, checkpoint_callback])