diff --git a/decloud/models/train_from_tfrecords.py b/decloud/models/train_from_tfrecords.py
index eb58d4e449e841342867a547296a92d4cd4e11dc..75ef766905e5bb695107aab6a2b263d55bc1e662 100644
--- a/decloud/models/train_from_tfrecords.py
+++ b/decloud/models/train_from_tfrecords.py
@@ -146,117 +146,122 @@ def main(args):
     expe_name += "_e{}".format(params.epochs)
     expe_name += suffix
 
-    if True:  # TODO: detete, just used for review
-        # Date tag
-        date_tag = time.strftime("%d-%m-%y-%H%M%S")
-
-        # adding the info to the SavedModel path
-        out_savedmodel = None if params.out_savedmodel is None else \
-            os.path.join(params.out_savedmodel, expe_name + date_tag)
-
-        # Scaling batch size and learning rate accordingly to number of workers
-        batch_size_train = params.batch_size_train * n_workers
-        batch_size_valid = params.batch_size_valid * n_workers
-        learning_rate = params.learning_rate * n_workers
-
-        logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
-                     learning_rate, batch_size_train, n_workers)
-
-        # Datasets
-        tfrecord_train = TFRecords(params.training_record) if params.training_record else None
-        tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
-
-        # Model instantiation
-        model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
-
-        # TF.dataset-s instantiation
-        tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
-                                          target_keys=model.model_output_keys,
-                                          n_workers=n_workers,
-                                          shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
-        tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
-                                     target_keys=model.model_output_keys,
-                                     n_workers=n_workers) for tfrecord in tfrecord_valid_array]
-
-        with strategy.scope():
-            # Creating the Keras network corresponding to the model
-            model.create_network()
-
-            # Metrics
-            metrics_list = [metrics.MeanSquaredError(), metrics.PSNR()]
-            if params.all_metrics:
-                metrics_list += [metrics.StructuralSimilarity(), metrics.SpectralAngle()]  # A bit slow to compute
-
-            # Creating the model or loading it from checkpoints
-            logging.info("Loading model \"%s\"", params.model)
-            model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
-                          loss=model.get_loss(),
-                          metrics={out_key: metrics_list for out_key in model.model_output_keys})
-            model.summary(strategy)
-
-            if params.plot_model:
-                model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
-
-            callbacks = []
-            # Define the checkpoint callback
-            if params.ckpt_dir:
-                if params.strategy == 'singlecpu':
-                    logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
-                else:
-                    backup_dir = os.path.join(params.ckpt_dir, params.model)
-                    # Backup (deleted once the model is trained the specified number of epochs)
-                    callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
-                    # Persistent save (still here after the model is trained)
-                    callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
-
-            # Define the Keras TensorBoard callback.
-            logdir = None
-            if params.logdir:
-                logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
-                tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
-                                                                   profile_batch=params.profiling)
-                callbacks.append(tensorboard_callback)
-
-            # Define the previews callback
-            if params.previews:
-                # We run the preview on an arbitrary sample of the validation dataset
-                sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
-                previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
-                                                     target_keys=model.model_output_keys)
-                callbacks.append(previews_callback)
-
-            # Validation on multiple datasets
-            if tf_ds_valid:
-                additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
-                callbacks.append(additional_validation_callback)
-
-            # Save best checkpoint only
-            if params.save_best:
-                callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
-                                                                 monitor=params.save_best_ref, mode='min'))
-
-            # Early stopping if the training stops improving
-            if params.early_stopping:
-                callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
-                                                               patience=10, mode='min'))
-
-            # Training
-            model.fit(tf_ds_train,
-                      epochs=params.epochs,
-                      validation_data=tf_ds_valid[0] if tf_ds_valid else None,
-                      callbacks=callbacks,
-                      verbose=1 if params.verbose else 2)
-
-        # Multiworker training tries to save the model multiple times and this can create corrupted models
-        # Thus we save the model at the final path only for the 'chief' worker
-        if params.strategy != 'singlecpu':
-            if not _is_chief(strategy):
-                out_savedmodel = None
-
-        # Export SavedModel
-        if out_savedmodel and not params.save_best:
-            logging.info("Saving SavedModel in %s", out_savedmodel)
-            model.save(out_savedmodel)
+    # Date tag
+    date_tag = time.strftime("%d-%m-%y-%H%M%S")
+
+    # adding the info to the SavedModel path
+    out_savedmodel = None if params.out_savedmodel is None else \
+        os.path.join(params.out_savedmodel, expe_name + date_tag)
+
+    # Scaling batch size and learning rate accordingly to number of workers
+    batch_size_train = params.batch_size_train * n_workers
+    batch_size_valid = params.batch_size_valid * n_workers
+    learning_rate = params.learning_rate * n_workers
+
+    logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
+                 learning_rate, batch_size_train, n_workers)
+
+    # Datasets
+    tfrecord_train = TFRecords(params.training_record) if params.training_record else None
+    tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
+
+    # Model instantiation
+    model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
+
+    # TF.dataset-s instantiation
+    tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
+                                      target_keys=model.model_output_keys,
+                                      n_workers=n_workers,
+                                      shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
+    tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
+                                 target_keys=model.model_output_keys,
+                                 n_workers=n_workers) for tfrecord in tfrecord_valid_array]
+
+    with strategy.scope():
+        # Creating the Keras network corresponding to the model
+        model.create_network()
+
+        # Metrics
+        metrics_list = [metrics.MeanSquaredError, metrics.PSNR]
+        if params.all_metrics:
+            metrics_list += [metrics.StructuralSimilarity, metrics.SpectralAngle]  # A bit slow to compute
+
+        # Creating the model or loading it from checkpoints
+        logging.info("Loading model \"%s\"", params.model)
+        model.compile(
+            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
+            loss=model.get_loss(),
+            metrics={
+                out_key: [metric()
+                          for metric in metrics_list]
+                for out_key in model.model_output_keys
+            }
+        )
+        model.summary(strategy)
+
+        if params.plot_model:
+            model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
+
+        callbacks = []
+        # Define the checkpoint callback
+        if params.ckpt_dir:
+            if params.strategy == 'singlecpu':
+                logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
+            else:
+                backup_dir = os.path.join(params.ckpt_dir, params.model)
+                # Backup (deleted once the model is trained the specified number of epochs)
+                callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
+                # Persistent save (still here after the model is trained)
+                callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
+
+        # Define the Keras TensorBoard callback.
+        logdir = None
+        if params.logdir:
+            logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
+            tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
+                                                               profile_batch=params.profiling)
+            callbacks.append(tensorboard_callback)
+
+        # Define the previews callback
+        if params.previews:
+            # We run the preview on an arbitrary sample of the validation dataset
+            sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
+            previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
+                                                 target_keys=model.model_output_keys)
+            callbacks.append(previews_callback)
+
+        # Validation on multiple datasets
+        if tf_ds_valid:
+            additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
+            callbacks.append(additional_validation_callback)
+
+        # Save best checkpoint only
+        if params.save_best:
+            callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
+                                                             monitor=params.save_best_ref, mode='min'))
+
+        # Early stopping if the training stops improving
+        if params.early_stopping:
+            callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
+                                                           patience=10, mode='min'))
+
+        # Training
+        model.fit(tf_ds_train,
+                  epochs=params.epochs,
+                  validation_data=tf_ds_valid[0] if tf_ds_valid else None,
+                  callbacks=callbacks,
+                  verbose=1 if params.verbose else 2)
+
+    # Multiworker training tries to save the model multiple times and this can create corrupted models
+    # Thus we save the model at the final path only for the 'chief' worker
+    if params.strategy != 'singlecpu':
+        if not _is_chief(strategy):
+            out_savedmodel = None
+
+    # Export SavedModel
+    if out_savedmodel and not params.save_best:
+        logging.info("Saving SavedModel in %s", out_savedmodel)
+        model.save(out_savedmodel)
 
 
 if __name__ == "__main__":