diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 0908dddf5..abcfff835 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -37,7 +37,7 @@ class Classifier(custom_model.CustomModel): label_names: A list of label names for the classes. shuffle: Whether the dataset should be shuffled. """ - super(Classifier, self).__init__(model_spec, shuffle) + super().__init__(model_spec, shuffle) self._label_names = label_names self._num_classes = len(label_names) self._model: tf.keras.Model = None @@ -48,7 +48,6 @@ class Classifier(custom_model.CustomModel): self._hparams: hp.BaseHParams = None self._history: tf.keras.callbacks.History = None - # TODO: Integrate this into GestureRecognizer. def _train_model(self, train_data: classification_ds.ClassificationDataset, validation_data: classification_ds.ClassificationDataset, diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index 556d2fcd7..b27f7161f 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -53,6 +53,10 @@ class GestureRecognizer(classifier.Classifier): model_spec=None, label_names=label_names, shuffle=hparams.shuffle) self._model_options = model_options self._hparams = hparams + self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma) + self._metric_function = 'categorical_accuracy' + self._optimizer = 'adam' + self._callbacks = self._get_callbacks() self._history = None self.embedding_size = _EMBEDDING_SIZE @@ -71,7 +75,7 @@ class GestureRecognizer(classifier.Classifier): Args: train_data: Training data. - validation_data: Validation data. If None, skips validation process. + validation_data: Validation data. options: options for creating and training gesture recognizer model. Returns: @@ -87,49 +91,39 @@ class GestureRecognizer(classifier.Classifier): label_names=train_data.label_names, model_options=options.model_options, hparams=options.hparams) - - gesture_recognizer._create_model() - - train_dataset = train_data.gen_tf_dataset( - batch_size=options.hparams.batch_size, - is_training=True, - shuffle=options.hparams.shuffle) - options.hparams.steps_per_epoch = model_util.get_steps_per_epoch( - steps_per_epoch=options.hparams.steps_per_epoch, - batch_size=options.hparams.batch_size, - train_data=train_data) - train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch) - - validation_dataset = validation_data.gen_tf_dataset( - batch_size=options.hparams.batch_size, is_training=False) - - tf.compat.v1.logging.info('Training the gesture recognizer model...') - gesture_recognizer._train( - train_data=train_dataset, validation_data=validation_dataset) - + gesture_recognizer._create_and_train_model(train_data, validation_data) return gesture_recognizer - def _train(self, train_data: tf.data.Dataset, - validation_data: tf.data.Dataset): - """Trains the model with input train_data. - - The training results are recorded by a self.History object returned by - tf.keras.Model.fit(). + def _create_and_train_model( + self, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + ): + """Creates and trains the model. Args: train_data: Training data. validation_data: Validation data. """ + self._create_model() + self._train_model( + train_data=train_data, + validation_data=validation_data, + checkpoint_path=self._get_checkpoint_path(), + ) + + def _get_callbacks(self) -> List[tf.keras.callbacks.Callback]: + """Gets the list of callbacks to use in model training.""" hparams = self._hparams scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch) scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler) job_dir = hparams.export_dir - checkpoint_path = os.path.join(job_dir, 'epoch_models') checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - os.path.join(checkpoint_path, 'model-{epoch:04d}'), - save_weights_only=True) + os.path.join(self._get_checkpoint_path(), 'model-{epoch:04d}'), + save_weights_only=True, + ) best_model_path = os.path.join(job_dir, 'best_model_weights') best_model_callback = tf.keras.callbacks.ModelCheckpoint( @@ -141,27 +135,15 @@ class GestureRecognizer(classifier.Classifier): tensorboard_callback = tf.keras.callbacks.TensorBoard( log_dir=os.path.join(job_dir, 'logs')) + return [ + checkpoint_callback, + best_model_callback, + scheduler_callback, + tensorboard_callback, + ] - self._model.compile( - optimizer='adam', - loss=loss_functions.FocalLoss(gamma=self._hparams.gamma), - metrics=['categorical_accuracy']) - - latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) - if latest_checkpoint: - print(f'Resuming from {latest_checkpoint}') - self._model.load_weights(latest_checkpoint) - - self._history = self._model.fit( - x=train_data, - epochs=hparams.epochs, - validation_data=validation_data, - validation_freq=1, - callbacks=[ - checkpoint_callback, best_model_callback, scheduler_callback, - tensorboard_callback - ], - ) + def _get_checkpoint_path(self) -> str: + return os.path.join(self._hparams.export_dir, 'epoch_models') def _create_model(self): """Creates the hand gesture recognizer model. @@ -172,7 +154,8 @@ class GestureRecognizer(classifier.Classifier): shape=[self.embedding_size], batch_size=None, dtype=tf.float32, - name='hand_embedding') + name='hand_embedding', + ) x = inputs dropout_rate = self._model_options.dropout_rate for i, width in enumerate(self._model_options.layer_widths):