Adds an AverageWordVecModel.

PiperOrigin-RevId: 487104909
This commit is contained in:
MediaPipe Team 2022-11-08 18:27:53 -08:00 committed by Copybara-Service
parent c31aaa94a6
commit a5bcb97d88

View File

@ -39,11 +39,10 @@ def get_default_callbacks(
"""Gets default callbacks."""
summary_dir = os.path.join(export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 20 epochs.
checkpoint_path = os.path.join(export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True, period=20)
checkpoint_path, save_weights_only=True)
return [summary_callback, checkpoint_callback]