Commit 4afa1639 authored by Remi Cresson

FIX: metrics bug (since TF 2.4, something like that)

1 merge request: !10 Update Dockerfile
Showing with 116 additions and 111 deletions
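In short: the old hunk wrapped the whole training block in a leftover `if True:  # TODO` from a review, and passed one shared list of already-instantiated metric objects to `model.compile` for every model output; the new hunk dedents the block and keeps metric *classes* instead, instantiating them once per output key, which the commit message ties to a Keras behaviour change around TF 2.4.
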
@@ -146,117 +146,122 @@ def main(args):
expe_name += "_e{}".format(params.epochs)
expe_name += suffix
if True:  # TODO: delete, just used for review
# Date tag
date_tag = time.strftime("%d-%m-%y-%H%M%S")
# adding the info to the SavedModel path
out_savedmodel = None if params.out_savedmodel is None else \
os.path.join(params.out_savedmodel, expe_name + date_tag)
# Scaling batch size and learning rate according to the number of workers
batch_size_train = params.batch_size_train * n_workers
batch_size_valid = params.batch_size_valid * n_workers
learning_rate = params.learning_rate * n_workers
logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
learning_rate, batch_size_train, n_workers)
# Datasets
tfrecord_train = TFRecords(params.training_record) if params.training_record else None
tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
# Model instantiation
model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
# tf.data Dataset instantiation
tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
target_keys=model.model_output_keys,
n_workers=n_workers,
shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
target_keys=model.model_output_keys,
n_workers=n_workers) for tfrecord in tfrecord_valid_array]
with strategy.scope():
# Creating the Keras network corresponding to the model
model.create_network()
# Metrics
metrics_list = [metrics.MeanSquaredError(), metrics.PSNR()]
if params.all_metrics:
metrics_list += [metrics.StructuralSimilarity(), metrics.SpectralAngle()] # A bit slow to compute
# Creating the model or loading it from checkpoints
logging.info("Loading model \"%s\"", params.model)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
loss=model.get_loss(),
metrics={out_key: metrics_list for out_key in model.model_output_keys})
model.summary(strategy)
if params.plot_model:
model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
callbacks = []
# Define the checkpoint callback
if params.ckpt_dir:
if params.strategy == 'singlecpu':
logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
else:
backup_dir = os.path.join(params.ckpt_dir, params.model)
# Backup (deleted once the model has been trained for the specified number of epochs)
callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
# Persistent save (kept after training completes)
callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
# Define the Keras TensorBoard callback.
logdir = None
if params.logdir:
logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
profile_batch=params.profiling)
callbacks.append(tensorboard_callback)
# Define the previews callback
if params.previews:
# We run the preview on an arbitrary sample of the validation dataset
sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
target_keys=model.model_output_keys)
callbacks.append(previews_callback)
# Validation on multiple datasets
if tf_ds_valid:
additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
callbacks.append(additional_validation_callback)
# Save best checkpoint only
if params.save_best:
callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
monitor=params.save_best_ref, mode='min'))
# Early stopping if the training stops improving
if params.early_stopping:
callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
patience=10, mode='min'))
# Training
model.fit(tf_ds_train,
epochs=params.epochs,
validation_data=tf_ds_valid[0] if tf_ds_valid else None,
callbacks=callbacks,
verbose=1 if params.verbose else 2)
# In multi-worker training every worker reaches the save call, and concurrent writes can corrupt the SavedModel
# Thus only the 'chief' worker saves the model at the final path
if params.strategy != 'singlecpu':
if not _is_chief(strategy):
out_savedmodel = None
# Export SavedModel
if out_savedmodel and not params.save_best:
logging.info("Saving SavedModel in %s", out_savedmodel)
model.save(out_savedmodel)
# Date tag
date_tag = time.strftime("%d-%m-%y-%H%M%S")
# adding the info to the SavedModel path
out_savedmodel = None if params.out_savedmodel is None else \
os.path.join(params.out_savedmodel, expe_name + date_tag)
# Scaling batch size and learning rate according to the number of workers
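# (linear scaling rule: with N workers the global batch is N times the
# per-worker batch, so the learning rate is scaled by the same factor)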
batch_size_train = params.batch_size_train * n_workers
batch_size_valid = params.batch_size_valid * n_workers
learning_rate = params.learning_rate * n_workers
logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
learning_rate, batch_size_train, n_workers)
# Datasets
tfrecord_train = TFRecords(params.training_record) if params.training_record else None
tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
# Model instantiation
model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
# tf.data Dataset instantiation
tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
target_keys=model.model_output_keys,
n_workers=n_workers,
shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
target_keys=model.model_output_keys,
n_workers=n_workers) for tfrecord in tfrecord_valid_array]
with strategy.scope():
# Creating the Keras network corresponding to the model
model.create_network()
# Metrics
metrics_list = [metrics.MeanSquaredError, metrics.PSNR]
if params.all_metrics:
metrics_list += [metrics.StructuralSimilarity, metrics.SpectralAngle] # A bit slow to compute
# Creating the model or loading it from checkpoints
logging.info("Loading model \"%s\"", params.model)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
loss=model.get_loss(),
metrics={
    # One fresh metric instance per output key: stateful Keras metrics
    # must not be shared across outputs, and every output gets the full list
    out_key: [metric() for metric in metrics_list]
    for out_key in model.model_output_keys
}
)
model.summary(strategy)
if params.plot_model:
model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
callbacks = []
# Define the checkpoint callback
if params.ckpt_dir:
if params.strategy == 'singlecpu':
logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
else:
backup_dir = os.path.join(params.ckpt_dir, params.model)
# Backup (deleted once the model has been trained for the specified number of epochs)
callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
# Persistent save (kept after training completes)
callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
# Define the Keras TensorBoard callback.
logdir = None
if params.logdir:
logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
profile_batch=params.profiling)
callbacks.append(tensorboard_callback)
# Define the previews callback
if params.previews:
# We run the preview on an arbitrary sample of the validation dataset
sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
target_keys=model.model_output_keys)
callbacks.append(previews_callback)
# Validation on multiple datasets
if tf_ds_valid:
additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
callbacks.append(additional_validation_callback)
# Save best checkpoint only
if params.save_best:
callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
monitor=params.save_best_ref, mode='min'))
# Early stopping if the training stops improving
if params.early_stopping:
callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
patience=10, mode='min'))
# Training
model.fit(tf_ds_train,
epochs=params.epochs,
validation_data=tf_ds_valid[0] if tf_ds_valid else None,
callbacks=callbacks,
verbose=1 if params.verbose else 2)
# In multi-worker training every worker reaches the save call, and concurrent writes can corrupt the SavedModel
# Thus only the 'chief' worker saves the model at the final path
if params.strategy != 'singlecpu':
if not _is_chief(strategy):
out_savedmodel = None
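# (_is_chief is defined elsewhere in this module; for MultiWorkerMirroredStrategy
# the chief is typically the task with type "chief", or worker 0 when the
# cluster defines no dedicated chief)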
# Export SavedModel
if out_savedmodel and not params.save_best:
logging.info("Saving SavedModel in %s", out_savedmodel)
model.save(out_savedmodel)
if __name__ == "__main__":
......
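
For reviewers unfamiliar with the Keras behaviour behind this fix, here is a minimal, self-contained sketch of the pattern the new hunk adopts. It uses stock tf.keras metrics and a toy two-output model as stand-ins for the project's metrics module and ModelFactory; all names below are illustrative, not from this repository.

import tensorflow as tf

# Toy two-output model standing in for the project's models
inputs = tf.keras.Input(shape=(8,), name="x")
out_a = tf.keras.layers.Dense(1, name="output_a")(inputs)
out_b = tf.keras.layers.Dense(1, name="output_b")(inputs)
model = tf.keras.Model(inputs, {"output_a": out_a, "output_b": out_b})

# Keep metric *classes*, as the new hunk does, and instantiate per output
metric_classes = [tf.keras.metrics.MeanSquaredError,
                  tf.keras.metrics.RootMeanSquaredError]
output_keys = ["output_a", "output_b"]

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss={key: "mse" for key in output_keys},
    # One fresh stateful metric object per (output, metric) pair; reusing a
    # single instance across outputs makes them share accumulator variables
    metrics={key: [cls() for cls in metric_classes] for key in output_keys},
)

With the old pattern (metrics={out_key: metrics_list ...} over pre-built instances), every output updates the same metric state, so the per-output values reported during training are conflated; instantiating per output keeps the states independent, which appears to be the TF >= 2.4 issue the commit message refers to.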