Internal model maker change.
PiperOrigin-RevId: 504472342
This commit is contained in:
parent
5dc81c4c27
commit
afb0182935
|
@ -37,7 +37,7 @@ class Classifier(custom_model.CustomModel):
|
|||
label_names: A list of label names for the classes.
|
||||
shuffle: Whether the dataset should be shuffled.
|
||||
"""
|
||||
super(Classifier, self).__init__(model_spec, shuffle)
|
||||
super().__init__(model_spec, shuffle)
|
||||
self._label_names = label_names
|
||||
self._num_classes = len(label_names)
|
||||
self._model: tf.keras.Model = None
|
||||
|
@ -48,7 +48,6 @@ class Classifier(custom_model.CustomModel):
|
|||
self._hparams: hp.BaseHParams = None
|
||||
self._history: tf.keras.callbacks.History = None
|
||||
|
||||
# TODO: Integrate this into GestureRecognizer.
|
||||
def _train_model(self,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset,
|
||||
|
|
|
@ -53,6 +53,10 @@ class GestureRecognizer(classifier.Classifier):
|
|||
model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
|
||||
self._model_options = model_options
|
||||
self._hparams = hparams
|
||||
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
|
||||
self._metric_function = 'categorical_accuracy'
|
||||
self._optimizer = 'adam'
|
||||
self._callbacks = self._get_callbacks()
|
||||
self._history = None
|
||||
self.embedding_size = _EMBEDDING_SIZE
|
||||
|
||||
|
@ -71,7 +75,7 @@ class GestureRecognizer(classifier.Classifier):
|
|||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data. If None, skips validation process.
|
||||
validation_data: Validation data.
|
||||
options: options for creating and training gesture recognizer model.
|
||||
|
||||
Returns:
|
||||
|
@ -87,49 +91,39 @@ class GestureRecognizer(classifier.Classifier):
|
|||
label_names=train_data.label_names,
|
||||
model_options=options.model_options,
|
||||
hparams=options.hparams)
|
||||
|
||||
gesture_recognizer._create_model()
|
||||
|
||||
train_dataset = train_data.gen_tf_dataset(
|
||||
batch_size=options.hparams.batch_size,
|
||||
is_training=True,
|
||||
shuffle=options.hparams.shuffle)
|
||||
options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=options.hparams.steps_per_epoch,
|
||||
batch_size=options.hparams.batch_size,
|
||||
train_data=train_data)
|
||||
train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)
|
||||
|
||||
validation_dataset = validation_data.gen_tf_dataset(
|
||||
batch_size=options.hparams.batch_size, is_training=False)
|
||||
|
||||
tf.compat.v1.logging.info('Training the gesture recognizer model...')
|
||||
gesture_recognizer._train(
|
||||
train_data=train_dataset, validation_data=validation_dataset)
|
||||
|
||||
gesture_recognizer._create_and_train_model(train_data, validation_data)
|
||||
return gesture_recognizer
|
||||
|
||||
def _train(self, train_data: tf.data.Dataset,
|
||||
validation_data: tf.data.Dataset):
|
||||
"""Trains the model with input train_data.
|
||||
|
||||
The training results are recorded by a self.History object returned by
|
||||
tf.keras.Model.fit().
|
||||
def _create_and_train_model(
|
||||
self,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset,
|
||||
):
|
||||
"""Creates and trains the model.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
"""
|
||||
self._create_model()
|
||||
self._train_model(
|
||||
train_data=train_data,
|
||||
validation_data=validation_data,
|
||||
checkpoint_path=self._get_checkpoint_path(),
|
||||
)
|
||||
|
||||
def _get_callbacks(self) -> List[tf.keras.callbacks.Callback]:
|
||||
"""Gets the list of callbacks to use in model training."""
|
||||
hparams = self._hparams
|
||||
|
||||
scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
|
||||
scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
|
||||
|
||||
job_dir = hparams.export_dir
|
||||
checkpoint_path = os.path.join(job_dir, 'epoch_models')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||
save_weights_only=True)
|
||||
os.path.join(self._get_checkpoint_path(), 'model-{epoch:04d}'),
|
||||
save_weights_only=True,
|
||||
)
|
||||
|
||||
best_model_path = os.path.join(job_dir, 'best_model_weights')
|
||||
best_model_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
|
@ -141,27 +135,15 @@ class GestureRecognizer(classifier.Classifier):
|
|||
|
||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
||||
log_dir=os.path.join(job_dir, 'logs'))
|
||||
return [
|
||||
checkpoint_callback,
|
||||
best_model_callback,
|
||||
scheduler_callback,
|
||||
tensorboard_callback,
|
||||
]
|
||||
|
||||
self._model.compile(
|
||||
optimizer='adam',
|
||||
loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
|
||||
metrics=['categorical_accuracy'])
|
||||
|
||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||
if latest_checkpoint:
|
||||
print(f'Resuming from {latest_checkpoint}')
|
||||
self._model.load_weights(latest_checkpoint)
|
||||
|
||||
self._history = self._model.fit(
|
||||
x=train_data,
|
||||
epochs=hparams.epochs,
|
||||
validation_data=validation_data,
|
||||
validation_freq=1,
|
||||
callbacks=[
|
||||
checkpoint_callback, best_model_callback, scheduler_callback,
|
||||
tensorboard_callback
|
||||
],
|
||||
)
|
||||
def _get_checkpoint_path(self) -> str:
|
||||
return os.path.join(self._hparams.export_dir, 'epoch_models')
|
||||
|
||||
def _create_model(self):
|
||||
"""Creates the hand gesture recognizer model.
|
||||
|
@ -172,7 +154,8 @@ class GestureRecognizer(classifier.Classifier):
|
|||
shape=[self.embedding_size],
|
||||
batch_size=None,
|
||||
dtype=tf.float32,
|
||||
name='hand_embedding')
|
||||
name='hand_embedding',
|
||||
)
|
||||
x = inputs
|
||||
dropout_rate = self._model_options.dropout_rate
|
||||
for i, width in enumerate(self._model_options.layer_widths):
|
||||
|
|
Loading…
Reference in New Issue
Block a user