From 24d03451c701f522025652c1cc7685db51d61967 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 7 Nov 2022 22:34:41 -0800 Subject: [PATCH] Internal change - migration PiperOrigin-RevId: 486853689 --- .../python/core/utils/model_util.py | 17 +++++++++++++++-- .../train_image_classifier_lib.py | 18 +----------------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 02d4f5b1e..ada0a61e3 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -19,7 +19,7 @@ from __future__ import print_function import os import tempfile -from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union # Dependency imports @@ -34,6 +34,19 @@ DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 ESTIMITED_STEPS_PER_EPOCH = 1000 +def get_default_callbacks( + export_dir: str) -> Sequence[tf.keras.callbacks.Callback]: + """Gets default callbacks.""" + summary_dir = os.path.join(export_dir, 'summaries') + summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) + # Save checkpoint every 20 epochs. + + checkpoint_path = os.path.join(export_dir, 'checkpoint') + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + checkpoint_path, save_weights_only=True, period=20) + return [summary_callback, checkpoint_callback] + + def load_keras_model(model_path: str, compile_on_load: bool = False) -> tf.keras.Model: """Loads a tensorflow Keras model from file and returns the Keras model. @@ -174,7 +187,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): lambda: self.decay_schedule_fn(step), name=name) - def get_config(self) -> Dict[Text, Any]: + def get_config(self) -> Dict[str, Any]: return { 'initial_learning_rate': self.initial_learning_rate, 'decay_schedule_fn': self.decay_schedule_fn, diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py index e31225514..4adddefeb 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py +++ b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py @@ -13,9 +13,6 @@ # limitations under the License. """Library to train model.""" -import os -from typing import List - import tensorflow as tf from mediapipe.model_maker.python.core.utils import model_util @@ -49,19 +46,6 @@ def _create_optimizer(init_lr: float, decay_steps: int, return optimizer -def _get_default_callbacks( - export_dir: str) -> List[tf.keras.callbacks.Callback]: - """Gets default callbacks.""" - summary_dir = os.path.join(export_dir, 'summaries') - summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) - # Save checkpoint every 20 epochs. - - checkpoint_path = os.path.join(export_dir, 'checkpoint') - checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - checkpoint_path, save_weights_only=True, period=20) - return [summary_callback, checkpoint_callback] - - def train_model(model: tf.keras.Model, hparams: hp.HParams, train_ds: tf.data.Dataset, validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History: @@ -94,7 +78,7 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams, loss = tf.keras.losses.CategoricalCrossentropy( label_smoothing=hparams.label_smoothing) model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) - callbacks = _get_default_callbacks(export_dir=hparams.export_dir) + callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir) # Train the model. return model.fit(