Commit 4afa1639 authored by Remi Cresson

FIX: metrics bug (since TF 2.4, something like that)

1 merge request: !10 Update Dockerfile
Showing 116 additions and 111 deletions
@@ -146,117 +146,122 @@ def main(args):
     expe_name += "_e{}".format(params.epochs)
     expe_name += suffix
-    if True: # TODO: detete, just used for review
     # Date tag
     date_tag = time.strftime("%d-%m-%y-%H%M%S")
 
     # adding the info to the SavedModel path
     out_savedmodel = None if params.out_savedmodel is None else \
         os.path.join(params.out_savedmodel, expe_name + date_tag)
 
     # Scaling batch size and learning rate accordingly to number of workers
     batch_size_train = params.batch_size_train * n_workers
     batch_size_valid = params.batch_size_valid * n_workers
     learning_rate = params.learning_rate * n_workers
 
     logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
                  learning_rate, batch_size_train, n_workers)
 
     # Datasets
     tfrecord_train = TFRecords(params.training_record) if params.training_record else None
     tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
 
     # Model instantiation
     model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
 
     # TF.dataset-s instantiation
     tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
                                       target_keys=model.model_output_keys,
                                       n_workers=n_workers,
                                       shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
     tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
                                  target_keys=model.model_output_keys,
                                  n_workers=n_workers) for tfrecord in tfrecord_valid_array]
 
     with strategy.scope():
         # Creating the Keras network corresponding to the model
         model.create_network()
 
         # Metrics
-            metrics_list = [metrics.MeanSquaredError(), metrics.PSNR()]
+        metrics_list = [metrics.MeanSquaredError, metrics.PSNR]
         if params.all_metrics:
-                metrics_list += [metrics.StructuralSimilarity(), metrics.SpectralAngle()]  # A bit slow to compute
+            metrics_list += [metrics.StructuralSimilarity, metrics.SpectralAngle]  # A bit slow to compute
 
         # Creating the model or loading it from checkpoints
         logging.info("Loading model \"%s\"", params.model)
-            model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
-                          loss=model.get_loss(),
-                          metrics={out_key: metrics_list for out_key in model.model_output_keys})
+        model.compile(
+            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
+            loss=model.get_loss(),
+            metrics={
+                out_key: metric()
+                for out_key in model.model_output_keys
+                for metric in metrics_list
+            }
+        )
         model.summary(strategy)
 
         if params.plot_model:
             model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
 
     callbacks = []
     # Define the checkpoint callback
     if params.ckpt_dir:
         if params.strategy == 'singlecpu':
             logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
         else:
             backup_dir = os.path.join(params.ckpt_dir, params.model)
             # Backup (deleted once the model is trained the specified number of epochs)
             callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
             # Persistent save (still here after the model is trained)
             callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
 
     # Define the Keras TensorBoard callback.
     logdir = None
     if params.logdir:
         logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
         tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
                                                            profile_batch=params.profiling)
         callbacks.append(tensorboard_callback)
 
     # Define the previews callback
     if params.previews:
         # We run the preview on an arbitrary sample of the validation dataset
         sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
         previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
                                              target_keys=model.model_output_keys)
         callbacks.append(previews_callback)
 
     # Validation on multiple datasets
     if tf_ds_valid:
         additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
         callbacks.append(additional_validation_callback)
 
     # Save best checkpoint only
     if params.save_best:
         callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
                                                          monitor=params.save_best_ref, mode='min'))
 
     # Early stopping if the training stops improving
     if params.early_stopping:
         callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
                                                        patience=10, mode='min'))
 
     # Training
     model.fit(tf_ds_train,
               epochs=params.epochs,
               validation_data=tf_ds_valid[0] if tf_ds_valid else None,
               callbacks=callbacks,
               verbose=1 if params.verbose else 2)
 
     # Multiworker training tries to save the model multiple times and this can create corrupted models
     # Thus we save the model at the final path only for the 'chief' worker
     if params.strategy != 'singlecpu':
         if not _is_chief(strategy):
             out_savedmodel = None
 
     # Export SavedModel
     if out_savedmodel and not params.save_best:
         logging.info("Saving SavedModel in %s", out_savedmodel)
         model.save(out_savedmodel)
 
 
 if __name__ == "__main__":
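The core of the fix is that metrics_list now holds metric classes instead of already-built instances, and model.compile() creates fresh instances for each output key. Below is a minimal, hypothetical sketch of that per-output instantiation pattern, using only built-in tf.keras metrics and a toy two-output model in place of the project's custom metrics module and ModelFactory; it illustrates the idea behind the commit rather than reproducing its code, since sharing one metric instance across several outputs is what stopped working with newer TensorFlow releases (around 2.4, per the commit message).

import tensorflow as tf

# Metric *classes*, mirroring the new metrics_list (built-ins stand in for the
# project's custom metrics such as PSNR or SpectralAngle).
metric_classes = [tf.keras.metrics.MeanSquaredError, tf.keras.metrics.MeanAbsoluteError]

# Toy two-output functional model standing in for the models built by ModelFactory.
inputs = tf.keras.Input(shape=(8,), name="x")
outputs = {
    "out_a": tf.keras.layers.Dense(1, name="out_a")(inputs),
    "out_b": tf.keras.layers.Dense(1, name="out_b")(inputs),
}
model = tf.keras.Model(inputs, outputs)
output_keys = list(outputs.keys())

# Each output key gets its own freshly instantiated metric objects, so no
# Metric (and its internal state variables) is shared between outputs.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss="mse",
    metrics={key: [cls() for cls in metric_classes] for key in output_keys},
)

Deferring instantiation to compile time means every output tracks its own independent metric state, which is what recent Keras versions expect for multi-output models.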