Commit 16a6f1e2 authored by Cresson Remi's avatar Cresson Remi
Browse files

FIX: backup/restore callbacks

1 merge request!6Checkpoints callbacks fixes
Showing with 8 additions and 10 deletions
+8 -10
...@@ -25,7 +25,6 @@ import os ...@@ -25,7 +25,6 @@ import os
import shutil import shutil
import tensorflow as tf import tensorflow as tf
from tensorflow import keras from tensorflow import keras
from decloud.core import system
from decloud.models.utils import _is_chief from decloud.models.utils import _is_chief
# Callbacks being called at the end of each epoch during training # Callbacks being called at the end of each epoch during training
...@@ -45,7 +44,7 @@ class ArchiveCheckpoint(keras.callbacks.Callback): ...@@ -45,7 +44,7 @@ class ArchiveCheckpoint(keras.callbacks.Callback):
self.backup_dir = backup_dir self.backup_dir = backup_dir
self.strategy = strategy self.strategy = strategy
def on_epoch_end(self, epoch, logs=None): def on_epoch_begin(self, epoch, logs=None):
""" """
At the end of each epoch, we save the directory of BackupAndRestore to a different name for archiving At the end of each epoch, we save the directory of BackupAndRestore to a different name for archiving
""" """
...@@ -92,7 +91,7 @@ class AdditionalValidationSets(keras.callbacks.Callback): ...@@ -92,7 +91,7 @@ class AdditionalValidationSets(keras.callbacks.Callback):
for metric, result in zip(self.model.metrics_names, results): for metric, result in zip(self.model.metrics_names, results):
if self.logdir: if self.logdir:
writer = tf.summary.create_file_writer(system.pathify(self.logdir) + 'validation_{}'.format(i + 1)) writer = tf.summary.create_file_writer(os.path.join(self.logdir, 'validation_{}'.format(i + 1)))
with writer.as_default(): with writer.as_default():
tf.summary.scalar('epoch_' + metric, result, step=epoch) # tensorboard adds an 'epoch_' prefix tf.summary.scalar('epoch_' + metric, result, step=epoch) # tensorboard adds an 'epoch_' prefix
else: else:
......
...@@ -152,7 +152,7 @@ def main(args): ...@@ -152,7 +152,7 @@ def main(args):
# adding the info to the SavedModel path # adding the info to the SavedModel path
out_savedmodel = None if params.out_savedmodel is None else \ out_savedmodel = None if params.out_savedmodel is None else \
system.pathify(params.out_savedmodel) + expe_name + date_tag os.path.join(params.out_savedmodel, expe_name + date_tag)
# Scaling batch size and learning rate accordingly to number of workers # Scaling batch size and learning rate accordingly to number of workers
batch_size_train = params.batch_size_train * n_workers batch_size_train = params.batch_size_train * n_workers
...@@ -203,17 +203,16 @@ def main(args): ...@@ -203,17 +203,16 @@ def main(args):
if params.strategy == 'singlecpu': if params.strategy == 'singlecpu':
logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints') logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
else: else:
# Create a backup backup_dir = os.path.join(params.ckpt_dir, params.model)
backup_dir = system.pathify(params.ckpt_dir) + params.model # Backup (deleted once the model is trained the specified number of epochs)
callbacks.append(keras.callbacks.experimental.BackupAndRestore(backup_dir=backup_dir)) callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
# Persistent save (still here after the model is trained)
# Save the checkpoint to a persistent location
callbacks.append(ArchiveCheckpoint(backup_dir, strategy)) callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
# Define the Keras TensorBoard callback. # Define the Keras TensorBoard callback.
logdir = None logdir = None
if params.logdir: if params.logdir:
logdir = system.pathify(params.logdir) + "{}_{}".format(date_tag, expe_name) logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir, tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
profile_batch=params.profiling) profile_batch=params.profiling)
callbacks.append(tensorboard_callback) callbacks.append(tensorboard_callback)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment