Commit 4afa1639 authored by Remi Cresson

FIX: metrics bug (since TF 2.4, something like that)
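Context for the fix: the commit message refers to a Keras behaviour change around TF 2.4, after which metric instances cannot be safely shared between the outputs of a multi-output model; each output needs its own instances. A minimal sketch of the pattern, using a hypothetical two-output model (names and shapes are illustrative, not from this repository):

import tensorflow as tf

# Hypothetical two-output Keras model, only to illustrate the metrics argument.
inputs = tf.keras.Input(shape=(8,))
model = tf.keras.Model(inputs, {"out_a": tf.keras.layers.Dense(1)(inputs),
                                "out_b": tf.keras.layers.Dense(1)(inputs)})

# Problematic pattern (pre-fix): the same metric instances are reused for both
# outputs, so their internal state is shared between outputs.
shared_metrics = [tf.keras.metrics.MeanSquaredError()]
# model.compile(optimizer="adam", loss="mse",
#               metrics={"out_a": shared_metrics, "out_b": shared_metrics})

# Fixed pattern: keep metric classes and instantiate them once per output.
metric_classes = [tf.keras.metrics.MeanSquaredError]
model.compile(optimizer="adam", loss="mse",
              metrics={key: [cls() for cls in metric_classes]
                       for key in ("out_a", "out_b")})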

This commit is part of merge request !10 (Update Dockerfile).
Showing with 116 additions and 111 deletions
@@ -146,117 +146,122 @@ def main(args):
expe_name += "_e{}".format(params.epochs)
expe_name += suffix
if True: # TODO: delete, just used for review
# Date tag
date_tag = time.strftime("%d-%m-%y-%H%M%S")
# adding the info to the SavedModel path
out_savedmodel = None if params.out_savedmodel is None else \
os.path.join(params.out_savedmodel, expe_name + date_tag)
# Scaling batch size and learning rate according to the number of workers
batch_size_train = params.batch_size_train * n_workers
batch_size_valid = params.batch_size_valid * n_workers
learning_rate = params.learning_rate * n_workers
logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
learning_rate, batch_size_train, n_workers)
# Datasets
tfrecord_train = TFRecords(params.training_record) if params.training_record else None
tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
# Model instantiation
model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
# TF.dataset-s instantiation
tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
target_keys=model.model_output_keys,
n_workers=n_workers,
shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
target_keys=model.model_output_keys,
n_workers=n_workers) for tfrecord in tfrecord_valid_array]
with strategy.scope():
# Creating the Keras network corresponding to the model
model.create_network()
# Metrics
metrics_list = [metrics.MeanSquaredError(), metrics.PSNR()]
if params.all_metrics:
metrics_list += [metrics.StructuralSimilarity(), metrics.SpectralAngle()] # A bit slow to compute
# Creating the model or loading it from checkpoints
logging.info("Loading model \"%s\"", params.model)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
loss=model.get_loss(),
metrics={out_key: metrics_list for out_key in model.model_output_keys})
model.summary(strategy)
if params.plot_model:
model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
callbacks = []
# Define the checkpoint callback
if params.ckpt_dir:
if params.strategy == 'singlecpu':
logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
else:
backup_dir = os.path.join(params.ckpt_dir, params.model)
# Backup (deleted once the model is trained the specified number of epochs)
callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
# Persistent save (still here after the model is trained)
callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
# Define the Keras TensorBoard callback.
logdir = None
if params.logdir:
logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
profile_batch=params.profiling)
callbacks.append(tensorboard_callback)
# Define the previews callback
if params.previews:
# We run the preview on an arbitrary sample of the validation dataset
sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
target_keys=model.model_output_keys)
callbacks.append(previews_callback)
# Validation on multiple datasets
if tf_ds_valid:
additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
callbacks.append(additional_validation_callback)
# Save best checkpoint only
if params.save_best:
callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
monitor=params.save_best_ref, mode='min'))
# Early stopping if the training stops improving
if params.early_stopping:
callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
patience=10, mode='min'))
# Training
model.fit(tf_ds_train,
epochs=params.epochs,
validation_data=tf_ds_valid[0] if tf_ds_valid else None,
callbacks=callbacks,
verbose=1 if params.verbose else 2)
# Multiworker training tries to save the model multiple times and this can create corrupted models
# Thus we save the model at the final path only for the 'chief' worker
if params.strategy != 'singlecpu':
if not _is_chief(strategy):
out_savedmodel = None
# Export SavedModel
if out_savedmodel and not params.save_best:
logging.info("Saving SavedModel in %s", out_savedmodel)
model.save(out_savedmodel)
# Date tag
date_tag = time.strftime("%d-%m-%y-%H%M%S")
# adding the info to the SavedModel path
out_savedmodel = None if params.out_savedmodel is None else \
os.path.join(params.out_savedmodel, expe_name + date_tag)
# Scaling batch size and learning rate according to the number of workers
batch_size_train = params.batch_size_train * n_workers
batch_size_valid = params.batch_size_valid * n_workers
learning_rate = params.learning_rate * n_workers
logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
learning_rate, batch_size_train, n_workers)
# Datasets
tfrecord_train = TFRecords(params.training_record) if params.training_record else None
tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
# Model instantiation
model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
# TF.dataset-s instantiation
tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
target_keys=model.model_output_keys,
n_workers=n_workers,
shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
target_keys=model.model_output_keys,
n_workers=n_workers) for tfrecord in tfrecord_valid_array]
with strategy.scope():
# Creating the Keras network corresponding to the model
model.create_network()
# Metrics
metrics_list = [metrics.MeanSquaredError, metrics.PSNR]
if params.all_metrics:
metrics_list += [metrics.StructuralSimilarity, metrics.SpectralAngle] # A bit slow to compute
# Creating the model or loading it from checkpoints
logging.info("Loading model \"%s\"", params.model)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
loss=model.get_loss(),
# One fresh list of metric instances per output (instances can not be shared across outputs since TF ~2.4)
metrics={
out_key: [metric() for metric in metrics_list]
for out_key in model.model_output_keys
}
)
model.summary(strategy)
if params.plot_model:
model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
callbacks = []
# Define the checkpoint callback
if params.ckpt_dir:
if params.strategy == 'singlecpu':
logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
else:
backup_dir = os.path.join(params.ckpt_dir, params.model)
# Backup (deleted once the model is trained the specified number of epochs)
callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
# Persistent save (still here after the model is trained)
callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
# Define the Keras TensorBoard callback.
logdir = None
if params.logdir:
logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
profile_batch=params.profiling)
callbacks.append(tensorboard_callback)
# Define the previews callback
if params.previews:
# We run the preview on an arbitrary sample of the validation dataset
sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
target_keys=model.model_output_keys)
callbacks.append(previews_callback)
# Validation on multiple datasets
if tf_ds_valid:
additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
callbacks.append(additional_validation_callback)
# Save best checkpoint only
if params.save_best:
callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
monitor=params.save_best_ref, mode='min'))
# Early stopping if the training stops improving
if params.early_stopping:
callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
patience=10, mode='min'))
# Training
model.fit(tf_ds_train,
epochs=params.epochs,
validation_data=tf_ds_valid[0] if tf_ds_valid else None,
callbacks=callbacks,
verbose=1 if params.verbose else 2)
# Multiworker training tries to save the model multiple times and this can create corrupted models
# Thus we save the model at the final path only for the 'chief' worker
if params.strategy != 'singlecpu':
if not _is_chief(strategy):
out_savedmodel = None
# Export SavedModel
if out_savedmodel and not params.save_best:
logging.info("Saving SavedModel in %s", out_savedmodel)
model.save(out_savedmodel)
if __name__ == "__main__":
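For context on the chief-only export at the end of the hunk: with a multi-worker strategy every worker runs the same script, so only the 'chief' task should write the final SavedModel. A minimal sketch of what a helper like _is_chief typically checks (an assumption about its shape, not the implementation from this repository):

import tensorflow as tf

def is_chief(strategy: tf.distribute.Strategy) -> bool:
    # MultiWorkerMirroredStrategy exposes a cluster resolver describing the
    # current task; the chief is the explicit 'chief' task or worker 0.
    resolver = getattr(strategy, "cluster_resolver", None)
    if resolver is None or resolver.task_type is None:
        # Single-worker strategies (e.g. MirroredStrategy): always the chief.
        return True
    return resolver.task_type == "chief" or (
        resolver.task_type == "worker" and resolver.task_id == 0)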
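Similarly, the scaling block near the top of the hunk follows the usual linear scaling rule for synchronous data-parallel training: each worker processes its own per-worker batch, so the effective (global) batch size is the per-worker batch size times the number of workers, and the learning rate is scaled by the same factor. A tiny sketch with hypothetical numbers (not taken from this repository):

n_workers = 4
per_worker_batch = 8
base_learning_rate = 1e-4

# Effective batch size grows linearly with the number of workers ...
batch_size_train = per_worker_batch * n_workers   # 32
# ... and the learning rate is scaled by the same factor (linear scaling rule).
learning_rate = base_learning_rate * n_workers    # 4e-4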