Internal model maker change.

PiperOrigin-RevId: 504472342
MediaPipe Team 2023-01-24 23:14:13 -08:00 committed by Copybara-Service
parent 5dc81c4c27
commit afb0182935
2 changed files with 35 additions and 53 deletions
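
At a glance, this change routes gesture-recognizer training through the shared Classifier training path instead of a task-specific _train method. A rough sketch of the resulting call flow, inferred from the hunks below and not part of the diff itself:

    # Hypothetical call flow after this change (inferred, not from the diff):
    recognizer = GestureRecognizer.create(train_data, validation_data, options)
    # create() builds the recognizer, then calls:
    #   recognizer._create_and_train_model(train_data, validation_data)
    #     -> self._create_model()      # build the Keras classification head
    #     -> self._train_model(        # shared Classifier training loop
    #            train_data=train_data,
    #            validation_data=validation_data,
    #            checkpoint_path=self._get_checkpoint_path())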


@@ -37,7 +37,7 @@ class Classifier(custom_model.CustomModel):
       label_names: A list of label names for the classes.
       shuffle: Whether the dataset should be shuffled.
     """
-    super(Classifier, self).__init__(model_spec, shuffle)
+    super().__init__(model_spec, shuffle)
     self._label_names = label_names
     self._num_classes = len(label_names)
     self._model: tf.keras.Model = None
@@ -48,7 +48,6 @@ class Classifier(custom_model.CustomModel):
     self._hparams: hp.BaseHParams = None
     self._history: tf.keras.callbacks.History = None

-  # TODO: Integrate this into GestureRecognizer.
   def _train_model(self,
                    train_data: classification_ds.ClassificationDataset,
                    validation_data: classification_ds.ClassificationDataset,
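
The gesture-recognizer-specific compile/fit logic removed further down in this diff now has to live behind Classifier._train_model. A minimal sketch of what that shared method plausibly does, assuming it consumes the subclass-set _optimizer, _loss_function, _metric_function, and _callbacks attributes plus the checkpoint_path argument shown in the hunks; the body below is an illustration, not the actual classifier.py code:

    def _train_model(self, train_data, validation_data, checkpoint_path=None):
      # Compile with whatever the subclass configured in __init__.
      self._model.compile(
          optimizer=self._optimizer,
          loss=self._loss_function,
          metrics=[self._metric_function])
      # Optionally resume from the newest per-epoch checkpoint.
      if checkpoint_path:
        latest = tf.train.latest_checkpoint(checkpoint_path)
        if latest:
          self._model.load_weights(latest)
      # Record training history, as the removed GestureRecognizer._train did.
      self._history = self._model.fit(
          x=train_data.gen_tf_dataset(
              batch_size=self._hparams.batch_size, is_training=True),
          epochs=self._hparams.epochs,
          validation_data=validation_data.gen_tf_dataset(
              batch_size=self._hparams.batch_size, is_training=False),
          callbacks=self._callbacks)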


@@ -53,6 +53,10 @@ class GestureRecognizer(classifier.Classifier):
         model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
     self._model_options = model_options
     self._hparams = hparams
+    self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
+    self._metric_function = 'categorical_accuracy'
+    self._optimizer = 'adam'
+    self._callbacks = self._get_callbacks()
     self._history = None
     self.embedding_size = _EMBEDDING_SIZE
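
The loss moved into an instance attribute here is a focal loss, which scales standard cross-entropy by (1 - p_t)^gamma so that well-classified examples contribute less to the gradient. A minimal self-contained sketch for one-hot labels; MediaPipe's own loss_functions.FocalLoss may weight or reduce differently, and gamma=2.0 below is only an illustrative default:

    import tensorflow as tf

    def focal_loss(y_true, y_pred, gamma=2.0, eps=1e-7):
      # y_true: one-hot labels, y_pred: softmax probabilities.
      y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
      cross_entropy = -y_true * tf.math.log(y_pred)
      modulating_factor = tf.pow(1.0 - y_pred, gamma)  # (1 - p_t)^gamma
      return tf.reduce_sum(modulating_factor * cross_entropy, axis=-1)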
@@ -71,7 +75,7 @@ class GestureRecognizer(classifier.Classifier):
     Args:
       train_data: Training data.
-      validation_data: Validation data. If None, skips validation process.
+      validation_data: Validation data.
       options: options for creating and training gesture recognizer model.

     Returns:
@@ -87,49 +91,39 @@ class GestureRecognizer(classifier.Classifier):
         label_names=train_data.label_names,
         model_options=options.model_options,
         hparams=options.hparams)
-    gesture_recognizer._create_model()
-
-    train_dataset = train_data.gen_tf_dataset(
-        batch_size=options.hparams.batch_size,
-        is_training=True,
-        shuffle=options.hparams.shuffle)
-    options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
-        steps_per_epoch=options.hparams.steps_per_epoch,
-        batch_size=options.hparams.batch_size,
-        train_data=train_data)
-    train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)
-
-    validation_dataset = validation_data.gen_tf_dataset(
-        batch_size=options.hparams.batch_size, is_training=False)
-
-    tf.compat.v1.logging.info('Training the gesture recognizer model...')
-    gesture_recognizer._train(
-        train_data=train_dataset, validation_data=validation_dataset)
+    gesture_recognizer._create_and_train_model(train_data, validation_data)

     return gesture_recognizer

-  def _train(self, train_data: tf.data.Dataset,
-             validation_data: tf.data.Dataset):
-    """Trains the model with input train_data.
-
-    The training results are recorded by a self.History object returned by
-    tf.keras.Model.fit().
+  def _create_and_train_model(
+      self,
+      train_data: classification_ds.ClassificationDataset,
+      validation_data: classification_ds.ClassificationDataset,
+  ):
+    """Creates and trains the model.

     Args:
       train_data: Training data.
       validation_data: Validation data.
     """
+    self._create_model()
+    self._train_model(
+        train_data=train_data,
+        validation_data=validation_data,
+        checkpoint_path=self._get_checkpoint_path(),
+    )
+
+  def _get_callbacks(self) -> List[tf.keras.callbacks.Callback]:
+    """Gets the list of callbacks to use in model training."""
     hparams = self._hparams
     scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
     scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
     job_dir = hparams.export_dir
-    checkpoint_path = os.path.join(job_dir, 'epoch_models')
     checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-        os.path.join(checkpoint_path, 'model-{epoch:04d}'),
-        save_weights_only=True)
+        os.path.join(self._get_checkpoint_path(), 'model-{epoch:04d}'),
+        save_weights_only=True,
+    )

     best_model_path = os.path.join(job_dir, 'best_model_weights')
     best_model_callback = tf.keras.callbacks.ModelCheckpoint(
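
The learning-rate scheduler kept as context above applies plain exponential decay per epoch. With illustrative values, not defaults taken from this diff, of learning_rate=0.001 and lr_decay=0.99:

    # lr(epoch) = learning_rate * lr_decay ** epoch
    # epoch 0 : 0.001 * 0.99**0  = 0.001
    # epoch 10: 0.001 * 0.99**10 ≈ 0.000904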
@@ -141,27 +135,15 @@ class GestureRecognizer(classifier.Classifier):
     tensorboard_callback = tf.keras.callbacks.TensorBoard(
         log_dir=os.path.join(job_dir, 'logs'))
+    return [
+        checkpoint_callback,
+        best_model_callback,
+        scheduler_callback,
+        tensorboard_callback,
+    ]

-    self._model.compile(
-        optimizer='adam',
-        loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
-        metrics=['categorical_accuracy'])
-
-    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
-    if latest_checkpoint:
-      print(f'Resuming from {latest_checkpoint}')
-      self._model.load_weights(latest_checkpoint)
-
-    self._history = self._model.fit(
-        x=train_data,
-        epochs=hparams.epochs,
-        validation_data=validation_data,
-        validation_freq=1,
-        callbacks=[
-            checkpoint_callback, best_model_callback, scheduler_callback,
-            tensorboard_callback
-        ],
-    )
+
+  def _get_checkpoint_path(self) -> str:
+    return os.path.join(self._hparams.export_dir, 'epoch_models')

   def _create_model(self):
     """Creates the hand gesture recognizer model.
@@ -172,7 +154,8 @@ class GestureRecognizer(classifier.Classifier):
         shape=[self.embedding_size],
         batch_size=None,
         dtype=tf.float32,
-        name='hand_embedding')
+        name='hand_embedding',
+    )
     x = inputs
     dropout_rate = self._model_options.dropout_rate
     for i, width in enumerate(self._model_options.layer_widths):
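
For context, the public entry point exercised by this refactor is unchanged. A usage sketch following the MediaPipe Model Maker gesture-recognizer customization guide; the dataset path, split ratios, and export directory are placeholders:

    from mediapipe_model_maker import gesture_recognizer

    # Load hand-gesture images organized in one folder per label.
    data = gesture_recognizer.Dataset.from_folder(
        dirname='path/to/gesture_image_folders',
        hparams=gesture_recognizer.HandDataPreprocessingParams())
    train_data, rest = data.split(0.8)
    validation_data, test_data = rest.split(0.5)

    # create() now builds and trains the model via the shared Classifier path.
    options = gesture_recognizer.GestureRecognizerOptions(
        hparams=gesture_recognizer.HParams(export_dir='exported_model'))
    model = gesture_recognizer.GestureRecognizer.create(
        train_data=train_data,
        validation_data=validation_data,
        options=options)
    model.export_model()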