Internal model maker change.
PiperOrigin-RevId: 504472342
This commit is contained in:
parent
5dc81c4c27
commit
afb0182935
|
@ -37,7 +37,7 @@ class Classifier(custom_model.CustomModel):
|
||||||
label_names: A list of label names for the classes.
|
label_names: A list of label names for the classes.
|
||||||
shuffle: Whether the dataset should be shuffled.
|
shuffle: Whether the dataset should be shuffled.
|
||||||
"""
|
"""
|
||||||
super(Classifier, self).__init__(model_spec, shuffle)
|
super().__init__(model_spec, shuffle)
|
||||||
self._label_names = label_names
|
self._label_names = label_names
|
||||||
self._num_classes = len(label_names)
|
self._num_classes = len(label_names)
|
||||||
self._model: tf.keras.Model = None
|
self._model: tf.keras.Model = None
|
||||||
|
@ -48,7 +48,6 @@ class Classifier(custom_model.CustomModel):
|
||||||
self._hparams: hp.BaseHParams = None
|
self._hparams: hp.BaseHParams = None
|
||||||
self._history: tf.keras.callbacks.History = None
|
self._history: tf.keras.callbacks.History = None
|
||||||
|
|
||||||
# TODO: Integrate this into GestureRecognizer.
|
|
||||||
def _train_model(self,
|
def _train_model(self,
|
||||||
train_data: classification_ds.ClassificationDataset,
|
train_data: classification_ds.ClassificationDataset,
|
||||||
validation_data: classification_ds.ClassificationDataset,
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
|
|
|
@ -53,6 +53,10 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
|
model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
|
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
|
||||||
|
self._metric_function = 'categorical_accuracy'
|
||||||
|
self._optimizer = 'adam'
|
||||||
|
self._callbacks = self._get_callbacks()
|
||||||
self._history = None
|
self._history = None
|
||||||
self.embedding_size = _EMBEDDING_SIZE
|
self.embedding_size = _EMBEDDING_SIZE
|
||||||
|
|
||||||
|
@ -71,7 +75,7 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
train_data: Training data.
|
train_data: Training data.
|
||||||
validation_data: Validation data. If None, skips validation process.
|
validation_data: Validation data.
|
||||||
options: options for creating and training gesture recognizer model.
|
options: options for creating and training gesture recognizer model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -87,49 +91,39 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
label_names=train_data.label_names,
|
label_names=train_data.label_names,
|
||||||
model_options=options.model_options,
|
model_options=options.model_options,
|
||||||
hparams=options.hparams)
|
hparams=options.hparams)
|
||||||
|
gesture_recognizer._create_and_train_model(train_data, validation_data)
|
||||||
gesture_recognizer._create_model()
|
|
||||||
|
|
||||||
train_dataset = train_data.gen_tf_dataset(
|
|
||||||
batch_size=options.hparams.batch_size,
|
|
||||||
is_training=True,
|
|
||||||
shuffle=options.hparams.shuffle)
|
|
||||||
options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
|
||||||
steps_per_epoch=options.hparams.steps_per_epoch,
|
|
||||||
batch_size=options.hparams.batch_size,
|
|
||||||
train_data=train_data)
|
|
||||||
train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)
|
|
||||||
|
|
||||||
validation_dataset = validation_data.gen_tf_dataset(
|
|
||||||
batch_size=options.hparams.batch_size, is_training=False)
|
|
||||||
|
|
||||||
tf.compat.v1.logging.info('Training the gesture recognizer model...')
|
|
||||||
gesture_recognizer._train(
|
|
||||||
train_data=train_dataset, validation_data=validation_dataset)
|
|
||||||
|
|
||||||
return gesture_recognizer
|
return gesture_recognizer
|
||||||
|
|
||||||
def _train(self, train_data: tf.data.Dataset,
|
def _create_and_train_model(
|
||||||
validation_data: tf.data.Dataset):
|
self,
|
||||||
"""Trains the model with input train_data.
|
train_data: classification_ds.ClassificationDataset,
|
||||||
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
The training results are recorded by a self.History object returned by
|
):
|
||||||
tf.keras.Model.fit().
|
"""Creates and trains the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
train_data: Training data.
|
train_data: Training data.
|
||||||
validation_data: Validation data.
|
validation_data: Validation data.
|
||||||
"""
|
"""
|
||||||
|
self._create_model()
|
||||||
|
self._train_model(
|
||||||
|
train_data=train_data,
|
||||||
|
validation_data=validation_data,
|
||||||
|
checkpoint_path=self._get_checkpoint_path(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_callbacks(self) -> List[tf.keras.callbacks.Callback]:
|
||||||
|
"""Gets the list of callbacks to use in model training."""
|
||||||
hparams = self._hparams
|
hparams = self._hparams
|
||||||
|
|
||||||
scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
|
scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
|
||||||
scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
|
scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
|
||||||
|
|
||||||
job_dir = hparams.export_dir
|
job_dir = hparams.export_dir
|
||||||
checkpoint_path = os.path.join(job_dir, 'epoch_models')
|
|
||||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
os.path.join(self._get_checkpoint_path(), 'model-{epoch:04d}'),
|
||||||
save_weights_only=True)
|
save_weights_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
best_model_path = os.path.join(job_dir, 'best_model_weights')
|
best_model_path = os.path.join(job_dir, 'best_model_weights')
|
||||||
best_model_callback = tf.keras.callbacks.ModelCheckpoint(
|
best_model_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||||
|
@ -141,27 +135,15 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
|
|
||||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
||||||
log_dir=os.path.join(job_dir, 'logs'))
|
log_dir=os.path.join(job_dir, 'logs'))
|
||||||
|
return [
|
||||||
|
checkpoint_callback,
|
||||||
|
best_model_callback,
|
||||||
|
scheduler_callback,
|
||||||
|
tensorboard_callback,
|
||||||
|
]
|
||||||
|
|
||||||
self._model.compile(
|
def _get_checkpoint_path(self) -> str:
|
||||||
optimizer='adam',
|
return os.path.join(self._hparams.export_dir, 'epoch_models')
|
||||||
loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
|
|
||||||
metrics=['categorical_accuracy'])
|
|
||||||
|
|
||||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
|
||||||
if latest_checkpoint:
|
|
||||||
print(f'Resuming from {latest_checkpoint}')
|
|
||||||
self._model.load_weights(latest_checkpoint)
|
|
||||||
|
|
||||||
self._history = self._model.fit(
|
|
||||||
x=train_data,
|
|
||||||
epochs=hparams.epochs,
|
|
||||||
validation_data=validation_data,
|
|
||||||
validation_freq=1,
|
|
||||||
callbacks=[
|
|
||||||
checkpoint_callback, best_model_callback, scheduler_callback,
|
|
||||||
tensorboard_callback
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_model(self):
|
def _create_model(self):
|
||||||
"""Creates the hand gesture recognizer model.
|
"""Creates the hand gesture recognizer model.
|
||||||
|
@ -172,7 +154,8 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
shape=[self.embedding_size],
|
shape=[self.embedding_size],
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
dtype=tf.float32,
|
dtype=tf.float32,
|
||||||
name='hand_embedding')
|
name='hand_embedding',
|
||||||
|
)
|
||||||
x = inputs
|
x = inputs
|
||||||
dropout_rate = self._model_options.dropout_rate
|
dropout_rate = self._model_options.dropout_rate
|
||||||
for i, width in enumerate(self._model_options.layer_widths):
|
for i, width in enumerate(self._model_options.layer_widths):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user