Commit 16a6f1e2 authored by Cresson Remi

FIX: backup/restore callbacks

1 merge request: !6 "Checkpoints callbacks fixes"
Showing 8 additions and 10 deletions (+8 -10)
@@ -25,7 +25,6 @@ import os
 import shutil
 import tensorflow as tf
 from tensorflow import keras
-from decloud.core import system
 from decloud.models.utils import _is_chief
 
 # Callbacks being called at the end of each epoch during training
@@ -45,7 +44,7 @@ class ArchiveCheckpoint(keras.callbacks.Callback):
         self.backup_dir = backup_dir
         self.strategy = strategy
 
-    def on_epoch_end(self, epoch, logs=None):
+    def on_epoch_begin(self, epoch, logs=None):
         """
         At the end of each epoch, we save the directory of BackupAndRestore to a different name for archiving
         """
@@ -92,7 +91,7 @@ class AdditionalValidationSets(keras.callbacks.Callback):
             for metric, result in zip(self.model.metrics_names, results):
                 if self.logdir:
-                    writer = tf.summary.create_file_writer(system.pathify(self.logdir) + 'validation_{}'.format(i + 1))
+                    writer = tf.summary.create_file_writer(os.path.join(self.logdir, 'validation_{}'.format(i + 1)))
                     with writer.as_default():
                         tf.summary.scalar('epoch_' + metric, result, step=epoch)  # tensorboard adds an 'epoch_' prefix
                 else:
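
The system.pathify(x) + y → os.path.join(x, y) rewrites in this commit are behavior-preserving provided pathify only ensured a trailing separator before string concatenation. A quick check under that assumption (this pathify body is a guess at the removed decloud.core.system helper):

    import os

    def pathify(pth):
        # Presumed behaviour of the removed helper: guarantee a trailing separator
        return pth if pth.endswith(os.sep) else pth + os.sep

    logdir = "/tmp/logs"
    assert pathify(logdir) + "validation_1" == os.path.join(logdir, "validation_1")
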
@@ -152,7 +152,7 @@ def main(args):
     # adding the info to the SavedModel path
     out_savedmodel = None if params.out_savedmodel is None else \
-        system.pathify(params.out_savedmodel) + expe_name + date_tag
+        os.path.join(params.out_savedmodel, expe_name + date_tag)
 
     # Scaling batch size and learning rate accordingly to number of workers
     batch_size_train = params.batch_size_train * n_workers
@@ -203,17 +203,16 @@ def main(args):
     if params.strategy == 'singlecpu':
         logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
     else:
-        # Create a backup
-        backup_dir = system.pathify(params.ckpt_dir) + params.model
-        callbacks.append(keras.callbacks.experimental.BackupAndRestore(backup_dir=backup_dir))
-        # Save the checkpoint to a persistent location
+        backup_dir = os.path.join(params.ckpt_dir, params.model)
+        # Backup (deleted once the model is trained the specified number of epochs)
+        callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
+        # Persistent save (still here after the model is trained)
         callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
 
     # Define the Keras TensorBoard callback.
     logdir = None
     if params.logdir:
-        logdir = system.pathify(params.logdir) + "{}_{}".format(date_tag, expe_name)
+        logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
         tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
                                                            profile_batch=params.profiling)
         callbacks.append(tensorboard_callback)
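
The other fix in this hunk replaces keras.callbacks.experimental.BackupAndRestore with keras.callbacks.BackupAndRestore, the stable name the callback received when it left the experimental namespace (TensorFlow 2.8). A self-contained sketch of the resulting fault-tolerant setup, with a toy model and hypothetical paths standing in for params:

    import os
    import tensorflow as tf
    from tensorflow import keras

    backup_dir = os.path.join("/tmp/ckpts", "demo_model")  # stand-in for params.ckpt_dir / params.model
    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    x, y = tf.random.normal((32, 4)), tf.random.normal((32, 1))

    callbacks = [
        # Checkpoints every epoch; a restarted fit() resumes from the last
        # completed epoch, and the directory is deleted once training finishes,
        # which is why decloud archives it separately with ArchiveCheckpoint.
        keras.callbacks.BackupAndRestore(backup_dir=backup_dir),
        keras.callbacks.TensorBoard(log_dir=os.path.join("/tmp/logs", "demo")),
    ]
    model.fit(x, y, epochs=5, callbacks=callbacks)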