Adds an AverageWordVecModel.

PiperOrigin-RevId: 487104909
This commit is contained in:
MediaPipe Team 2022-11-08 18:27:53 -08:00 committed by Copybara-Service
parent c31aaa94a6
commit a5bcb97d88

View File

@ -39,11 +39,10 @@ def get_default_callbacks(
"""Gets default callbacks.""" """Gets default callbacks."""
summary_dir = os.path.join(export_dir, 'summaries') summary_dir = os.path.join(export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 20 epochs.
checkpoint_path = os.path.join(export_dir, 'checkpoint') checkpoint_path = os.path.join(export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True, period=20) checkpoint_path, save_weights_only=True)
return [summary_callback, checkpoint_callback] return [summary_callback, checkpoint_callback]