Model Maker object detector: change learning_rate_boundaries to learning_rate_epoch_boundaries.

PiperOrigin-RevId: 521024056
MediaPipe Team 2023-03-31 15:17:13 -07:00 committed by Copybara-Service
parent 7f9fd4f154
commit d9f940f8b2
2 changed files with 47 additions and 26 deletions

mediapipe/model_maker/python/vision/object_detector/hyperparameters.py

@@ -29,9 +29,9 @@ class HParams(hp.BaseHParams):
     epochs: Number of training iterations over the dataset.
     do_fine_tuning: If true, the base module is trained together with the
       classification layer on top.
-    learning_rate_boundaries: List of epoch boundaries where
-      learning_rate_boundaries[i] is the epoch where the learning rate will
-      decay to learning_rate * learning_rate_decay_multipliers[i].
+    learning_rate_epoch_boundaries: List of epoch boundaries where
+      learning_rate_epoch_boundaries[i] is the epoch where the learning rate
+      will decay to learning_rate * learning_rate_decay_multipliers[i].
     learning_rate_decay_multipliers: List of learning rate multipliers which
       calculates the learning rate at the ith boundary as learning_rate *
       learning_rate_decay_multipliers[i].
@@ -43,35 +43,39 @@ class HParams(hp.BaseHParams):
   epochs: int = 10
 
   # Parameters for learning rate decay
-  learning_rate_boundaries: List[int] = dataclasses.field(
-      default_factory=lambda: [5, 8]
+  learning_rate_epoch_boundaries: List[int] = dataclasses.field(
+      default_factory=lambda: []
   )
   learning_rate_decay_multipliers: List[float] = dataclasses.field(
-      default_factory=lambda: [0.1, 0.01]
+      default_factory=lambda: []
   )
 
   def __post_init__(self):
     # Validate stepwise learning rate parameters
-    lr_boundary_len = len(self.learning_rate_boundaries)
+    lr_boundary_len = len(self.learning_rate_epoch_boundaries)
     lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers)
     if lr_boundary_len != lr_decay_multipliers_len:
       raise ValueError(
-          "Length of learning_rate_boundaries and ",
+          "Length of learning_rate_epoch_boundaries and ",
           "learning_rate_decay_multipliers do not match: ",
           f"{lr_boundary_len}!={lr_decay_multipliers_len}",
       )
-    # Validate learning_rate_boundaries
-    if sorted(self.learning_rate_boundaries) != self.learning_rate_boundaries:
-      raise ValueError(
-          "learning_rate_boundaries is not in ascending order: ",
-          self.learning_rate_boundaries,
-      )
+    # Validate learning_rate_epoch_boundaries
     if (
-        self.learning_rate_boundaries
-        and self.learning_rate_boundaries[-1] > self.epochs
+        sorted(self.learning_rate_epoch_boundaries)
+        != self.learning_rate_epoch_boundaries
     ):
       raise ValueError(
-          "Values in learning_rate_boundaries cannot be greater ", "than epochs"
+          "learning_rate_epoch_boundaries is not in ascending order: ",
+          self.learning_rate_epoch_boundaries,
+      )
+    if (
+        self.learning_rate_epoch_boundaries
+        and self.learning_rate_epoch_boundaries[-1] > self.epochs
+    ):
+      raise ValueError(
+          "Values in learning_rate_epoch_boundaries cannot be greater ",
+          "than epochs",
       )
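
Note: with the new empty-list defaults, stepwise learning rate decay is opt-in rather than on by default. A minimal usage sketch of the renamed field through the public Model Maker API follows (the import path follows the published Model Maker docs; the concrete values are illustrative, not from this commit):

from mediapipe_model_maker import object_detector

hparams = object_detector.HParams(
    epochs=30,
    # Renamed field: boundaries are epochs, must be ascending and <= epochs.
    learning_rate_epoch_boundaries=[20, 25],
    # Must match learning_rate_epoch_boundaries in length.
    learning_rate_decay_multipliers=[0.1, 0.01],
)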

mediapipe/model_maker/python/vision/object_detector/object_detector.py

@@ -57,7 +57,6 @@ class ObjectDetector(classifier.Classifier):
     self._preprocessor = preprocessor.Preprocessor(model_spec)
     self._hparams = hparams
     self._model_options = model_options
-    self._optimizer = self._create_optimizer()
     self._is_qat = False
 
   @classmethod
@classmethod @classmethod
@@ -104,6 +103,11 @@ class ObjectDetector(classifier.Classifier):
       train_data: Training data.
       validation_data: Validation data.
     """
+    self._optimizer = self._create_optimizer(
+        model_util.get_steps_per_epoch(
+            self._hparams.steps_per_epoch,
+        )
+    )
     self._create_model()
     self._train_model(
         train_data, validation_data, preprocessor=self._preprocessor
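
Note: the hunk above routes self._hparams.steps_per_epoch through model_util.get_steps_per_epoch before handing it to _create_optimizer. The helper's full signature is not visible in this diff; the following is only a hedged sketch of the shape such a helper needs, with assumed names and fallback logic, not the actual MediaPipe implementation:

def get_steps_per_epoch(steps_per_epoch=None, batch_size=None, train_data=None):
  # Prefer an explicit override; otherwise derive steps from the dataset size.
  if steps_per_epoch is not None:
    return steps_per_epoch
  return len(train_data) // batch_size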
@@ -333,21 +337,34 @@ class ObjectDetector(classifier.Classifier):
     with open(metadata_file, 'w') as f:
       f.write(metadata_json)
 
-  def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
+  def _create_optimizer(
+      self, steps_per_epoch: int
+  ) -> tf.keras.optimizers.Optimizer:
     """Creates an optimizer with learning rate schedule for regular training.
 
     Uses Keras PiecewiseConstantDecay schedule by default.
 
+    Args:
+      steps_per_epoch: Steps per epoch to calculate the step boundaries from
+        the learning_rate_epoch_boundaries.
+
     Returns:
       A tf.keras.optimizer.Optimizer for model training.
     """
     init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
-    lr_values = [init_lr] + [
-        init_lr * m for m in self._hparams.learning_rate_decay_multipliers
-    ]
-    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
-        self._hparams.learning_rate_boundaries, lr_values
-    )
+    if self._hparams.learning_rate_epoch_boundaries:
+      lr_values = [init_lr] + [
+          init_lr * m for m in self._hparams.learning_rate_decay_multipliers
+      ]
+      lr_step_boundaries = [
+          steps_per_epoch * epoch_boundary
+          for epoch_boundary in self._hparams.learning_rate_epoch_boundaries
+      ]
+      learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+          lr_step_boundaries, lr_values
+      )
+    else:
+      learning_rate = init_lr
     return tf.keras.optimizers.experimental.SGD(
-        learning_rate=learning_rate_fn, momentum=0.9
+        learning_rate=learning_rate, momentum=0.9
     )
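
Note: Keras learning rate schedules are indexed by optimizer step, not by epoch, which is why the new code multiplies each epoch boundary by steps_per_epoch before building the schedule. A self-contained sketch of that arithmetic (the numbers are illustrative; 0.3 and 8 mirror typical learning_rate and batch_size defaults, and are not taken from this commit):

import tensorflow as tf

steps_per_epoch = 100
init_lr = 0.3 * 8 / 256  # learning_rate * batch_size / 256, as in the diff
epoch_boundaries = [5, 8]
decay_multipliers = [0.1, 0.01]

lr_values = [init_lr] + [init_lr * m for m in decay_multipliers]
step_boundaries = [steps_per_epoch * e for e in epoch_boundaries]  # [500, 800]

schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    step_boundaries, lr_values
)
# Keras returns lr_values[0] while step <= 500, lr_values[1] while
# step <= 800, and lr_values[2] for the rest of training.
print(schedule(0).numpy(), schedule(600).numpy(), schedule(900).numpy())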