diff --git a/decloud/models/meraner_unet_all_bands.py b/decloud/models/meraner_unet_all_bands.py
index 206f16a4ff54f4a2101c4fa8cf8a20102b89f824..76e2389ea0033188bf7d7733af0b80a29b79b515 100644
--- a/decloud/models/meraner_unet_all_bands.py
+++ b/decloud/models/meraner_unet_all_bands.py
@@ -31,12 +31,25 @@ class meraner_unet_all_bands(Model):
     """
     Implementation of a variant of the Meraner et al. network (all bands)
     """
-    def __init__(self, dataset_shapes,
-                 dataset_input_keys=["s1_t", "s2_t", "s2_20m_t",
-                                     constants.DEM_KEY],
-                 model_output_keys=["s2_target"]):
-        super().__init__(dataset_input_keys=dataset_input_keys, model_output_keys=model_output_keys,
-                         dataset_shapes=dataset_shapes)
+    def __init__(
+            self,
+            dataset_shapes,
+            dataset_input_keys=[
+                "s1_t",
+                "s2_t",
+                "s2_20m_t",
+                constants.DEM_KEY
+            ],
+            model_output_keys=[
+                "s2_target",
+                "s2_20m_target"
+            ]
+    ):
+        super().__init__(
+            dataset_input_keys=dataset_input_keys,
+            model_output_keys=model_output_keys,
+            dataset_shapes=dataset_shapes
+        )
 
     def get_outputs(self, normalized_inputs):
         # The network
@@ -105,4 +118,8 @@ class meraner_unet_all_bands(Model):
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
         s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
-        return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
+        return {
+            "s2_target": s2_out,
+            "s2_20m_target": s2_20m_out,
+            "s2_all_bands_estim": s2_all_bands
+        }