Model Maker object detector change learning_rate_boundaries to learning_rate_epoch_boundaries.

PiperOrigin-RevId: 521024056
This commit is contained in:
MediaPipe Team 2023-03-31 15:17:13 -07:00 committed by Copybara-Service
parent 7f9fd4f154
commit d9f940f8b2
2 changed files with 47 additions and 26 deletions

View File

@ -29,9 +29,9 @@ class HParams(hp.BaseHParams):
epochs: Number of training iterations over the dataset.
do_fine_tuning: If true, the base module is trained together with the
classification layer on top.
learning_rate_boundaries: List of epoch boundaries where
learning_rate_boundaries[i] is the epoch where the learning rate will
decay to learning_rate * learning_rate_decay_multipliers[i].
learning_rate_epoch_boundaries: List of epoch boundaries where
learning_rate_epoch_boundaries[i] is the epoch where the learning rate
will decay to learning_rate * learning_rate_decay_multipliers[i].
learning_rate_decay_multipliers: List of learning rate multipliers which
calculates the learning rate at the ith boundary as learning_rate *
learning_rate_decay_multipliers[i].
@ -43,35 +43,39 @@ class HParams(hp.BaseHParams):
epochs: int = 10
# Parameters for learning rate decay
learning_rate_boundaries: List[int] = dataclasses.field(
default_factory=lambda: [5, 8]
learning_rate_epoch_boundaries: List[int] = dataclasses.field(
default_factory=lambda: []
)
learning_rate_decay_multipliers: List[float] = dataclasses.field(
default_factory=lambda: [0.1, 0.01]
default_factory=lambda: []
)
def __post_init__(self):
# Validate stepwise learning rate parameters
lr_boundary_len = len(self.learning_rate_boundaries)
lr_boundary_len = len(self.learning_rate_epoch_boundaries)
lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers)
if lr_boundary_len != lr_decay_multipliers_len:
raise ValueError(
"Length of learning_rate_boundaries and ",
"Length of learning_rate_epoch_boundaries and ",
"learning_rate_decay_multipliers do not match: ",
f"{lr_boundary_len}!={lr_decay_multipliers_len}",
)
# Validate learning_rate_boundaries
if sorted(self.learning_rate_boundaries) != self.learning_rate_boundaries:
raise ValueError(
"learning_rate_boundaries is not in ascending order: ",
self.learning_rate_boundaries,
)
# Validate learning_rate_epoch_boundaries
if (
self.learning_rate_boundaries
and self.learning_rate_boundaries[-1] > self.epochs
sorted(self.learning_rate_epoch_boundaries)
!= self.learning_rate_epoch_boundaries
):
raise ValueError(
"Values in learning_rate_boundaries cannot be greater ", "than epochs"
"learning_rate_epoch_boundaries is not in ascending order: ",
self.learning_rate_epoch_boundaries,
)
if (
self.learning_rate_epoch_boundaries
and self.learning_rate_epoch_boundaries[-1] > self.epochs
):
raise ValueError(
"Values in learning_rate_epoch_boundaries cannot be greater ",
"than epochs",
)

View File

@ -57,7 +57,6 @@ class ObjectDetector(classifier.Classifier):
self._preprocessor = preprocessor.Preprocessor(model_spec)
self._hparams = hparams
self._model_options = model_options
self._optimizer = self._create_optimizer()
self._is_qat = False
@classmethod
@ -104,6 +103,11 @@ class ObjectDetector(classifier.Classifier):
train_data: Training data.
validation_data: Validation data.
"""
self._optimizer = self._create_optimizer(
model_util.get_steps_per_epoch(
self._hparams.steps_per_epoch,
)
)
self._create_model()
self._train_model(
train_data, validation_data, preprocessor=self._preprocessor
@ -333,21 +337,34 @@ class ObjectDetector(classifier.Classifier):
with open(metadata_file, 'w') as f:
f.write(metadata_json)
def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
def _create_optimizer(
self, steps_per_epoch: int
) -> tf.keras.optimizers.Optimizer:
"""Creates an optimizer with learning rate schedule for regular training.
Uses Keras PiecewiseConstantDecay schedule by default.
Args:
steps_per_epoch: Steps per epoch to calculate the step boundaries from the
learning_rate_epoch_boundaries
Returns:
A tf.keras.optimizer.Optimizer for model training.
"""
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
if self._hparams.learning_rate_epoch_boundaries:
lr_values = [init_lr] + [
init_lr * m for m in self._hparams.learning_rate_decay_multipliers
]
learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
self._hparams.learning_rate_boundaries, lr_values
lr_step_boundaries = [
steps_per_epoch * epoch_boundary
for epoch_boundary in self._hparams.learning_rate_epoch_boundaries
]
learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
lr_step_boundaries, lr_values
)
else:
learning_rate = init_lr
return tf.keras.optimizers.experimental.SGD(
learning_rate=learning_rate_fn, momentum=0.9
learning_rate=learning_rate, momentum=0.9
)