Internal change - migration
PiperOrigin-RevId: 486853689
This commit is contained in:
parent 0a08e4768b
commit 24d03451c7
@@ -19,7 +19,7 @@ from __future__ import print_function
 
 import os
 import tempfile
-from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 # Dependency imports
 
@@ -34,6 +34,19 @@ DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
 ESTIMITED_STEPS_PER_EPOCH = 1000
 
 
+def get_default_callbacks(
+    export_dir: str) -> Sequence[tf.keras.callbacks.Callback]:
+  """Gets default callbacks."""
+  summary_dir = os.path.join(export_dir, 'summaries')
+  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
+  # Save checkpoint every 20 epochs.
+
+  checkpoint_path = os.path.join(export_dir, 'checkpoint')
+  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
+      checkpoint_path, save_weights_only=True, period=20)
+  return [summary_callback, checkpoint_callback]
+
+
 def load_keras_model(model_path: str,
                      compile_on_load: bool = False) -> tf.keras.Model:
   """Loads a tensorflow Keras model from file and returns the Keras model.
@@ -174,7 +187,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
           lambda: self.decay_schedule_fn(step),
           name=name)
 
-  def get_config(self) -> Dict[Text, Any]:
+  def get_config(self) -> Dict[str, Any]:
     return {
         'initial_learning_rate': self.initial_learning_rate,
         'decay_schedule_fn': self.decay_schedule_fn,
@@ -13,9 +13,6 @@
 # limitations under the License.
 """Library to train model."""
 
-import os
-from typing import List
-
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.utils import model_util
@@ -49,19 +46,6 @@ def _create_optimizer(init_lr: float, decay_steps: int,
   return optimizer
 
 
-def _get_default_callbacks(
-    export_dir: str) -> List[tf.keras.callbacks.Callback]:
-  """Gets default callbacks."""
-  summary_dir = os.path.join(export_dir, 'summaries')
-  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
-  # Save checkpoint every 20 epochs.
-
-  checkpoint_path = os.path.join(export_dir, 'checkpoint')
-  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      checkpoint_path, save_weights_only=True, period=20)
-  return [summary_callback, checkpoint_callback]
-
-
 def train_model(model: tf.keras.Model, hparams: hp.HParams,
                 train_ds: tf.data.Dataset,
                 validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
@@ -94,7 +78,7 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
   loss = tf.keras.losses.CategoricalCrossentropy(
       label_smoothing=hparams.label_smoothing)
   model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
-  callbacks = _get_default_callbacks(export_dir=hparams.export_dir)
+  callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir)
 
   # Train the model.
   return model.fit(
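
The net effect of the diff is that the private _get_default_callbacks helper is removed from the training library and exposed as model_util.get_default_callbacks in the shared utility module, with its return type annotated as Sequence instead of List. Below is a minimal usage sketch, not part of the commit: the toy model, random dataset, and temporary export_dir are illustrative assumptions, and only model_util.get_default_callbacks comes from the code above.

# Sketch only: a throwaway model and dataset to exercise the relocated helper.
import tempfile

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util

# Hypothetical toy classifier, just to make the example self-contained.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='softmax', input_shape=(8,)),
])
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
    metrics=['accuracy'])

# Random stand-in data shaped like (features, one-hot labels).
features = tf.random.uniform((32, 8))
labels = tf.one_hot(tf.random.uniform((32,), maxval=10, dtype=tf.int32), 10)
train_ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

# After this change, the TensorBoard and checkpoint callbacks come from the
# shared utility module instead of the training library's private helper.
export_dir = tempfile.mkdtemp()
callbacks = model_util.get_default_callbacks(export_dir=export_dir)
model.fit(train_ds, epochs=2, callbacks=callbacks)

The helper still returns the same [TensorBoard, ModelCheckpoint] pair as before the migration; only its home module and return annotation change, so train_model's behavior is unchanged.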