diff --git a/decloud/models/train_from_tfrecords.py b/decloud/models/train_from_tfrecords.py
index eb58d4e449e841342867a547296a92d4cd4e11dc..75ef766905e5bb695107aab6a2b263d55bc1e662 100644
--- a/decloud/models/train_from_tfrecords.py
+++ b/decloud/models/train_from_tfrecords.py
@@ -146,117 +146,122 @@ def main(args):
     expe_name += "_e{}".format(params.epochs)
     expe_name += suffix
 
-    if True:  # TODO: detete, just used for review
-        # Date tag
-        date_tag = time.strftime("%d-%m-%y-%H%M%S")
-
-        # adding the info to the SavedModel path
-        out_savedmodel = None if params.out_savedmodel is None else \
-            os.path.join(params.out_savedmodel, expe_name + date_tag)
-
-        # Scaling batch size and learning rate accordingly to number of workers
-        batch_size_train = params.batch_size_train * n_workers
-        batch_size_valid = params.batch_size_valid * n_workers
-        learning_rate = params.learning_rate * n_workers
-
-        logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
-                     learning_rate, batch_size_train, n_workers)
-
-        # Datasets
-        tfrecord_train = TFRecords(params.training_record) if params.training_record else None
-        tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
-
-        # Model instantiation
-        model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
-
-        # TF.dataset-s instantiation
-        tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
-                                          target_keys=model.model_output_keys,
-                                          n_workers=n_workers,
-                                          shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
-        tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
-                                     target_keys=model.model_output_keys,
-                                     n_workers=n_workers) for tfrecord in tfrecord_valid_array]
-
-        with strategy.scope():
-            # Creating the Keras network corresponding to the model
-            model.create_network()
-
-            # Metrics
-            metrics_list = [metrics.MeanSquaredError(), metrics.PSNR()]
-            if params.all_metrics:
-                metrics_list += [metrics.StructuralSimilarity(), metrics.SpectralAngle()]  # A bit slow to compute
-
-            # Creating the model or loading it from checkpoints
-            logging.info("Loading model \"%s\"", params.model)
-            model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
-                          loss=model.get_loss(),
-                          metrics={out_key: metrics_list for out_key in model.model_output_keys})
-            model.summary(strategy)
-
-            if params.plot_model:
-                model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
-
-            callbacks = []
-            # Define the checkpoint callback
-            if params.ckpt_dir:
-                if params.strategy == 'singlecpu':
-                    logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
-                else:
-                    backup_dir = os.path.join(params.ckpt_dir, params.model)
-                    # Backup (deleted once the model is trained the specified number of epochs)
-                    callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
-                    # Persistent save (still here after the model is trained)
-                    callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
-
-            # Define the Keras TensorBoard callback.
-            logdir = None
-            if params.logdir:
-                logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
-                tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
-                                                                   profile_batch=params.profiling)
-                callbacks.append(tensorboard_callback)
-
-                # Define the previews callback
-                if params.previews:
-                    # We run the preview on an arbitrary sample of the validation dataset
-                    sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
-                    previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
-                                                         target_keys=model.model_output_keys)
-                    callbacks.append(previews_callback)
-
-            # Validation on multiple datasets
-            if tf_ds_valid:
-                additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
-                callbacks.append(additional_validation_callback)
-
-            # Save best checkpoint only
-            if params.save_best:
-                callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
-                                                                 monitor=params.save_best_ref, mode='min'))
-
-            # Early stopping if the training stops improving
-            if params.early_stopping:
-                callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
-                                                               patience=10, mode='min'))
-
-            # Training
-            model.fit(tf_ds_train,
-                      epochs=params.epochs,
-                      validation_data=tf_ds_valid[0] if tf_ds_valid else None,
-                      callbacks=callbacks,
-                      verbose=1 if params.verbose else 2)
-
-            # Multiworker training tries to save the model multiple times and this can create corrupted models
-            # Thus we save the model at the final path only for the 'chief' worker
-            if params.strategy != 'singlecpu':
-                if not _is_chief(strategy):
-                    out_savedmodel = None
-
-            # Export SavedModel
-            if out_savedmodel and not params.save_best:
-                logging.info("Saving SavedModel in %s", out_savedmodel)
-                model.save(out_savedmodel)
+    # Date tag
+    date_tag = time.strftime("%d-%m-%y-%H%M%S")
+
+    # adding the info to the SavedModel path
+    out_savedmodel = None if params.out_savedmodel is None else \
+        os.path.join(params.out_savedmodel, expe_name + date_tag)
+
+    # Scaling batch size and learning rate accordingly to number of workers
+    batch_size_train = params.batch_size_train * n_workers
+    batch_size_valid = params.batch_size_valid * n_workers
+    learning_rate = params.learning_rate * n_workers
+
+    logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
+                 learning_rate, batch_size_train, n_workers)
+
+    # Datasets
+    tfrecord_train = TFRecords(params.training_record) if params.training_record else None
+    tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
+
+    # Model instantiation
+    model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
+
+    # TF.dataset-s instantiation
+    tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
+                                      target_keys=model.model_output_keys,
+                                      n_workers=n_workers,
+                                      shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
+    tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
+                                 target_keys=model.model_output_keys,
+                                 n_workers=n_workers) for tfrecord in tfrecord_valid_array]
+
+    with strategy.scope():
+        # Creating the Keras network corresponding to the model
+        model.create_network()
+
+        # Metrics
+        metrics_list = [metrics.MeanSquaredError, metrics.PSNR]
+        if params.all_metrics:
+            metrics_list += [metrics.StructuralSimilarity, metrics.SpectralAngle]  # A bit slow to compute
+
+        # Creating the model or loading it from checkpoints
+        logging.info("Loading model \"%s\"", params.model)
+        model.compile(
+            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
+            loss=model.get_loss(),
+            metrics={
+                out_key: metric()
+                for out_key in model.model_output_keys
+                for metric in metrics_list
+            }
+        )
+        model.summary(strategy)
+
+        if params.plot_model:
+            model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
+
+        callbacks = []
+        # Define the checkpoint callback
+        if params.ckpt_dir:
+            if params.strategy == 'singlecpu':
+                logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
+            else:
+                backup_dir = os.path.join(params.ckpt_dir, params.model)
+                # Backup (deleted once the model is trained the specified number of epochs)
+                callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
+                # Persistent save (still here after the model is trained)
+                callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
+
+        # Define the Keras TensorBoard callback.
+        logdir = None
+        if params.logdir:
+            logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
+            tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
+                                                               profile_batch=params.profiling)
+            callbacks.append(tensorboard_callback)
+
+            # Define the previews callback
+            if params.previews:
+                # We run the preview on an arbitrary sample of the validation dataset
+                sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
+                previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
+                                                     target_keys=model.model_output_keys)
+                callbacks.append(previews_callback)
+
+        # Validation on multiple datasets
+        if tf_ds_valid:
+            additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
+            callbacks.append(additional_validation_callback)
+
+        # Save best checkpoint only
+        if params.save_best:
+            callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
+                                                             monitor=params.save_best_ref, mode='min'))
+
+        # Early stopping if the training stops improving
+        if params.early_stopping:
+            callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
+                                                           patience=10, mode='min'))
+
+        # Training
+        model.fit(tf_ds_train,
+                  epochs=params.epochs,
+                  validation_data=tf_ds_valid[0] if tf_ds_valid else None,
+                  callbacks=callbacks,
+                  verbose=1 if params.verbose else 2)
+
+        # Multiworker training tries to save the model multiple times and this can create corrupted models
+        # Thus we save the model at the final path only for the 'chief' worker
+        if params.strategy != 'singlecpu':
+            if not _is_chief(strategy):
+                out_savedmodel = None
+
+        # Export SavedModel
+        if out_savedmodel and not params.save_best:
+            logging.info("Saving SavedModel in %s", out_savedmodel)
+            model.save(out_savedmodel)
 
 
 if __name__ == "__main__":