Internal model maker change.

PiperOrigin-RevId: 504472342
This commit is contained in:
MediaPipe Team 2023-01-24 23:14:13 -08:00 committed by Copybara-Service
parent 5dc81c4c27
commit afb0182935
2 changed files with 35 additions and 53 deletions

View File

@ -37,7 +37,7 @@ class Classifier(custom_model.CustomModel):
label_names: A list of label names for the classes. label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled. shuffle: Whether the dataset should be shuffled.
""" """
super(Classifier, self).__init__(model_spec, shuffle) super().__init__(model_spec, shuffle)
self._label_names = label_names self._label_names = label_names
self._num_classes = len(label_names) self._num_classes = len(label_names)
self._model: tf.keras.Model = None self._model: tf.keras.Model = None
@ -48,7 +48,6 @@ class Classifier(custom_model.CustomModel):
self._hparams: hp.BaseHParams = None self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None self._history: tf.keras.callbacks.History = None
# TODO: Integrate this into GestureRecognizer.
def _train_model(self, def _train_model(self,
train_data: classification_ds.ClassificationDataset, train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset, validation_data: classification_ds.ClassificationDataset,

View File

@ -53,6 +53,10 @@ class GestureRecognizer(classifier.Classifier):
model_spec=None, label_names=label_names, shuffle=hparams.shuffle) model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
self._model_options = model_options self._model_options = model_options
self._hparams = hparams self._hparams = hparams
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
self._metric_function = 'categorical_accuracy'
self._optimizer = 'adam'
self._callbacks = self._get_callbacks()
self._history = None self._history = None
self.embedding_size = _EMBEDDING_SIZE self.embedding_size = _EMBEDDING_SIZE
@ -71,7 +75,7 @@ class GestureRecognizer(classifier.Classifier):
Args: Args:
train_data: Training data. train_data: Training data.
validation_data: Validation data. If None, skips validation process. validation_data: Validation data.
options: options for creating and training gesture recognizer model. options: options for creating and training gesture recognizer model.
Returns: Returns:
@ -87,49 +91,39 @@ class GestureRecognizer(classifier.Classifier):
label_names=train_data.label_names, label_names=train_data.label_names,
model_options=options.model_options, model_options=options.model_options,
hparams=options.hparams) hparams=options.hparams)
gesture_recognizer._create_and_train_model(train_data, validation_data)
gesture_recognizer._create_model()
train_dataset = train_data.gen_tf_dataset(
batch_size=options.hparams.batch_size,
is_training=True,
shuffle=options.hparams.shuffle)
options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=options.hparams.steps_per_epoch,
batch_size=options.hparams.batch_size,
train_data=train_data)
train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)
validation_dataset = validation_data.gen_tf_dataset(
batch_size=options.hparams.batch_size, is_training=False)
tf.compat.v1.logging.info('Training the gesture recognizer model...')
gesture_recognizer._train(
train_data=train_dataset, validation_data=validation_dataset)
return gesture_recognizer return gesture_recognizer
def _train(self, train_data: tf.data.Dataset, def _create_and_train_model(
validation_data: tf.data.Dataset): self,
"""Trains the model with input train_data. train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
The training results are recorded by a self.History object returned by ):
tf.keras.Model.fit(). """Creates and trains the model.
Args: Args:
train_data: Training data. train_data: Training data.
validation_data: Validation data. validation_data: Validation data.
""" """
self._create_model()
self._train_model(
train_data=train_data,
validation_data=validation_data,
checkpoint_path=self._get_checkpoint_path(),
)
def _get_callbacks(self) -> List[tf.keras.callbacks.Callback]:
"""Gets the list of callbacks to use in model training."""
hparams = self._hparams hparams = self._hparams
scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch) scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler) scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
job_dir = hparams.export_dir job_dir = hparams.export_dir
checkpoint_path = os.path.join(job_dir, 'epoch_models')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(checkpoint_path, 'model-{epoch:04d}'), os.path.join(self._get_checkpoint_path(), 'model-{epoch:04d}'),
save_weights_only=True) save_weights_only=True,
)
best_model_path = os.path.join(job_dir, 'best_model_weights') best_model_path = os.path.join(job_dir, 'best_model_weights')
best_model_callback = tf.keras.callbacks.ModelCheckpoint( best_model_callback = tf.keras.callbacks.ModelCheckpoint(
@ -141,27 +135,15 @@ class GestureRecognizer(classifier.Classifier):
tensorboard_callback = tf.keras.callbacks.TensorBoard( tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=os.path.join(job_dir, 'logs')) log_dir=os.path.join(job_dir, 'logs'))
return [
checkpoint_callback,
best_model_callback,
scheduler_callback,
tensorboard_callback,
]
self._model.compile( def _get_checkpoint_path(self) -> str:
optimizer='adam', return os.path.join(self._hparams.export_dir, 'epoch_models')
loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
metrics=['categorical_accuracy'])
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
if latest_checkpoint:
print(f'Resuming from {latest_checkpoint}')
self._model.load_weights(latest_checkpoint)
self._history = self._model.fit(
x=train_data,
epochs=hparams.epochs,
validation_data=validation_data,
validation_freq=1,
callbacks=[
checkpoint_callback, best_model_callback, scheduler_callback,
tensorboard_callback
],
)
def _create_model(self): def _create_model(self):
"""Creates the hand gesture recognizer model. """Creates the hand gesture recognizer model.
@ -172,7 +154,8 @@ class GestureRecognizer(classifier.Classifier):
shape=[self.embedding_size], shape=[self.embedding_size],
batch_size=None, batch_size=None,
dtype=tf.float32, dtype=tf.float32,
name='hand_embedding') name='hand_embedding',
)
x = inputs x = inputs
dropout_rate = self._model_options.dropout_rate dropout_rate = self._model_options.dropout_rate
for i, width in enumerate(self._model_options.layer_widths): for i, width in enumerate(self._model_options.layer_widths):