Internal change - migration

PiperOrigin-RevId: 486853689
This commit is contained in:
MediaPipe Team 2022-11-07 22:34:41 -08:00 committed by Copybara-Service
parent 0a08e4768b
commit 24d03451c7
2 changed files with 16 additions and 19 deletions

View File

@ -19,7 +19,7 @@ from __future__ import print_function
import os import os
import tempfile import tempfile
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
# Dependency imports # Dependency imports
@ -34,6 +34,19 @@ DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
ESTIMITED_STEPS_PER_EPOCH = 1000 ESTIMITED_STEPS_PER_EPOCH = 1000
def get_default_callbacks(
export_dir: str) -> Sequence[tf.keras.callbacks.Callback]:
"""Gets default callbacks."""
summary_dir = os.path.join(export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 20 epochs.
checkpoint_path = os.path.join(export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True, period=20)
return [summary_callback, checkpoint_callback]
def load_keras_model(model_path: str, def load_keras_model(model_path: str,
compile_on_load: bool = False) -> tf.keras.Model: compile_on_load: bool = False) -> tf.keras.Model:
"""Loads a tensorflow Keras model from file and returns the Keras model. """Loads a tensorflow Keras model from file and returns the Keras model.
@ -174,7 +187,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
lambda: self.decay_schedule_fn(step), lambda: self.decay_schedule_fn(step),
name=name) name=name)
def get_config(self) -> Dict[Text, Any]: def get_config(self) -> Dict[str, Any]:
return { return {
'initial_learning_rate': self.initial_learning_rate, 'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn, 'decay_schedule_fn': self.decay_schedule_fn,

View File

@ -13,9 +13,6 @@
# limitations under the License. # limitations under the License.
"""Library to train model.""" """Library to train model."""
import os
from typing import List
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
@ -49,19 +46,6 @@ def _create_optimizer(init_lr: float, decay_steps: int,
return optimizer return optimizer
def _get_default_callbacks(
export_dir: str) -> List[tf.keras.callbacks.Callback]:
"""Gets default callbacks."""
summary_dir = os.path.join(export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 20 epochs.
checkpoint_path = os.path.join(export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True, period=20)
return [summary_callback, checkpoint_callback]
def train_model(model: tf.keras.Model, hparams: hp.HParams, def train_model(model: tf.keras.Model, hparams: hp.HParams,
train_ds: tf.data.Dataset, train_ds: tf.data.Dataset,
validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History: validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
@ -94,7 +78,7 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
loss = tf.keras.losses.CategoricalCrossentropy( loss = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=hparams.label_smoothing) label_smoothing=hparams.label_smoothing)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
callbacks = _get_default_callbacks(export_dir=hparams.export_dir) callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir)
# Train the model. # Train the model.
return model.fit( return model.fit(