Commit 4afa1639 authored by Remi Cresson

FIX: metrics bug (since TF 2.4, something like that)

This commit is part of merge request !10 (Update Dockerfile).
Showing with 116 additions and 111 deletions
@@ -146,117 +146,122 @@ def main(args):
    expe_name += "_e{}".format(params.epochs)
    expe_name += suffix

    # Date tag
    date_tag = time.strftime("%d-%m-%y-%H%M%S")

    # adding the info to the SavedModel path
    out_savedmodel = None if params.out_savedmodel is None else \
        os.path.join(params.out_savedmodel, expe_name + date_tag)

    # Scaling batch size and learning rate accordingly to number of workers
    batch_size_train = params.batch_size_train * n_workers
    batch_size_valid = params.batch_size_valid * n_workers
    learning_rate = params.learning_rate * n_workers
    logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
                 learning_rate, batch_size_train, n_workers)

    # Datasets
    tfrecord_train = TFRecords(params.training_record) if params.training_record else None
    tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]

    # Model instantiation
    model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)

    # TF.dataset-s instantiation
    tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
                                      target_keys=model.model_output_keys,
                                      n_workers=n_workers,
                                      shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
    tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
                                 target_keys=model.model_output_keys,
                                 n_workers=n_workers) for tfrecord in tfrecord_valid_array]

    with strategy.scope():
        # Creating the Keras network corresponding to the model
        model.create_network()

        # Metrics
        metrics_list = [metrics.MeanSquaredError, metrics.PSNR]
        if params.all_metrics:
            metrics_list += [metrics.StructuralSimilarity, metrics.SpectralAngle]  # A bit slow to compute

        # Creating the model or loading it from checkpoints
        logging.info("Loading model \"%s\"", params.model)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss=model.get_loss(),
            metrics={
                out_key: metric()
                for out_key in model.model_output_keys
                for metric in metrics_list
            }
        )
        model.summary(strategy)

        if params.plot_model:
            model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)

        callbacks = []
        # Define the checkpoint callback
        if params.ckpt_dir:
            if params.strategy == 'singlecpu':
                logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
            else:
                backup_dir = os.path.join(params.ckpt_dir, params.model)
                # Backup (deleted once the model is trained the specified number of epochs)
                callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
                # Persistent save (still here after the model is trained)
                callbacks.append(ArchiveCheckpoint(backup_dir, strategy))

        # Define the Keras TensorBoard callback.
        logdir = None
        if params.logdir:
            logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
            tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
                                                               profile_batch=params.profiling)
            callbacks.append(tensorboard_callback)

        # Define the previews callback
        if params.previews:
            # We run the preview on an arbitrary sample of the validation dataset
            sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
            previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
                                                 target_keys=model.model_output_keys)
            callbacks.append(previews_callback)

        # Validation on multiple datasets
        if tf_ds_valid:
            additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
            callbacks.append(additional_validation_callback)

        # Save best checkpoint only
        if params.save_best:
            callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
                                                             monitor=params.save_best_ref, mode='min'))

        # Early stopping if the training stops improving
        if params.early_stopping:
            callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
                                                           patience=10, mode='min'))

        # Training
        model.fit(tf_ds_train,
                  epochs=params.epochs,
                  validation_data=tf_ds_valid[0] if tf_ds_valid else None,
                  callbacks=callbacks,
                  verbose=1 if params.verbose else 2)

        # Multiworker training tries to save the model multiple times and this can create corrupted models
        # Thus we save the model at the final path only for the 'chief' worker
        if params.strategy != 'singlecpu':
            if not _is_chief(strategy):
                out_savedmodel = None

        # Export SavedModel
        if out_savedmodel and not params.save_best:
            logging.info("Saving SavedModel in %s", out_savedmodel)
            model.save(out_savedmodel)


if __name__ == "__main__":
    ...
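For reference, the pattern this fix relies on can be sketched with stock tf.keras, outside the project's codebase. This is a minimal, hypothetical example: the two-output model and the names "out_a"/"out_b" stand in for the project's ModelFactory models and model_output_keys, and the built-in Keras metrics replace the project's custom metrics module. The idea is to keep metric classes in a list and instantiate them freshly for each output at compile time, so that no stateful Metric object is shared between outputs (the sharing that the commit message associates with TF 2.4 and later).

import tensorflow as tf

# Hypothetical two-output regression model; "out_a" and "out_b" stand in for
# the project's model_output_keys.
inputs = tf.keras.Input(shape=(8,))
out_a = tf.keras.layers.Dense(1, name="out_a")(inputs)
out_b = tf.keras.layers.Dense(1, name="out_b")(inputs)
model = tf.keras.Model(inputs, [out_a, out_b])

# Keep metric *classes* in the list, not instances.
metrics_list = [tf.keras.metrics.MeanSquaredError, tf.keras.metrics.RootMeanSquaredError]

# Build a fresh instance of every metric for each output at compile time,
# so that no stateful Metric object is shared between the two outputs.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss="mse",
    metrics={key: [metric() for metric in metrics_list] for key in ("out_a", "out_b")},
)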