diff --git a/decloud/models/crga_os1_unet.py b/decloud/models/crga_os1_unet.py
index 4302cb263c022a085e594e7fad87ddb8fb505d2b..62e1ac182b295c0e7f7bd1ed1bec149aed8cda54 100644
--- a/decloud/models/crga_os1_unet.py
+++ b/decloud/models/crga_os1_unet.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 import decloud.preprocessing.constants as constants
 from decloud.models.crga_os1_base import crga_os1_base
+from tensorflow import concat
 
 
 class crga_os1_unet(crga_os1_base):
@@ -57,14 +58,14 @@ class crga_os1_unet(crga_os1_base):
             if input_image == "current":
                 net = conv1_s1(input_dict[input_image])  # 256
             else:
-                net = layers.concatenate(input_dict[input_image], axis=-1)
+                net = concat(input_dict[input_image], axis=-1)
                 net = conv1_s1s2(net)  # 256
 
             features[1].append(net)
             net = conv2(net)  # 128
             if self.has_dem():
                 net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-                net = layers.concatenate([net, net_dem], axis=-1)
+                net = concat([net, net_dem], axis=-1)
             features[2].append(net)
             net = conv3(net)  # 64
             features[4].append(net)
@@ -79,7 +80,7 @@ class crga_os1_unet(crga_os1_base):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
diff --git a/decloud/models/crga_os1_unet_all_bands.py b/decloud/models/crga_os1_unet_all_bands.py
index 252dd1a50834a569c8161a76999984ec718417c1..ac044383e016e9f94dd5ee973801d8069c74e367 100644
--- a/decloud/models/crga_os1_unet_all_bands.py
+++ b/decloud/models/crga_os1_unet_all_bands.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 import decloud.preprocessing.constants as constants
 from decloud.models.crga_os1_base_all_bands import crga_os1_base_all_bands
+from tensorflow import concat
 
 
 class crga_os1_unet_all_bands(crga_os1_base_all_bands):
@@ -69,7 +70,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
                 features[1].append(net_10m)
                 net = conv2(net_10m)  # 128
             else:  # for post & ante, the is s1, s2 and s2_20m
-                net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1)
+                net_10m = concat(input_dict[input_image][:2], axis=-1)
                 net_10m = conv1_s1s2(net_10m)  # 256
                 features[1].append(net_10m)
                 net_10m = conv2(net_10m)  # 128
@@ -77,7 +78,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
                 features_20m = [net_10m, net_20m]
                 if self.has_dem():
                     features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
-                net = layers.concatenate(features_20m, axis=-1)
+                net = concat(features_20m, axis=-1)
                 net = conv2_20m(net)  # 128
 
             features[2].append(net)
@@ -94,7 +95,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
@@ -114,6 +115,6 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
 
         # 10m-resampled stack that will be the output for inference (not used for training)
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
-        s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1)
+        s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
         return {"s2_t": s2_out, "s2_20m_t": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
diff --git a/decloud/models/crga_os2_david.py b/decloud/models/crga_os2_david.py
index 62efac978d57b77cdbc81763023318602a48399d..837d6121e23a7dd88d27a8aa40573f068efc22e1 100644
--- a/decloud/models/crga_os2_david.py
+++ b/decloud/models/crga_os2_david.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.crga_os2_base import crga_os2_base
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class crga_os2_david(crga_os2_base):
@@ -50,16 +51,16 @@ class crga_os2_david(crga_os2_base):
         deconv2 = layers.Conv2DTranspose(64, 3, 2, activation='relu', name="deconv2_bn_relu", padding="same")
         conv4 = layers.Conv2D(4, 5, 1, activation='relu', name="s2_estim", padding="same")
         for input_image in input_dict:
-            net = layers.concatenate(input_dict[input_image], axis=-1)
+            net = concat(input_dict[input_image], axis=-1)
             net = conv1(net)  # 256
             net = conv2(net)  # 128
             if self.has_dem():
                 net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-                net = layers.concatenate([net, net_dem], axis=-1)
+                net = concat([net, net_dem], axis=-1)
             net = conv3(net)  # 64
             features.append(net)
 
-        net = layers.concatenate(features, axis=-1)
+        net = concat(features, axis=-1)
         net = deconv1(net)    # 128
         net = deconv2(net)    # 256
         s2_out = conv4(net)   # 256
diff --git a/decloud/models/crga_os2_david_all_bands.py b/decloud/models/crga_os2_david_all_bands.py
index 0b1f34b13ff4f0b597b133acf37267fe68e8bbe1..247fa1a5daa41f926fb338128d15eaac535f68fb 100644
--- a/decloud/models/crga_os2_david_all_bands.py
+++ b/decloud/models/crga_os2_david_all_bands.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.crga_os2_base_all_bands import crga_os2_base_all_bands
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class crga_os2_david_all_bands(crga_os2_base_all_bands):
@@ -56,7 +57,7 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
         conv4 = layers.Conv2D(4, 5, 1, activation='relu', name="s2_estim", padding="same")
         conv4_20m = layers.Conv2D(6, 3, 1, activation='relu', name="s2_20m_estim", padding="same")
         for input_image in input_dict:
-            net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1)
+            net_10m = concat(input_dict[input_image][:2], axis=-1)
             net_10m = conv1(net_10m)  # 256
             net_10m = conv2(net_10m)  # 128
             net_20m = conv1_20m(input_dict[input_image][2])  # 128
@@ -64,11 +65,11 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
             features_20m = [net_10m, net_20m]
             if self.has_dem():
                 features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
-            net = layers.concatenate(features_20m, axis=-1)  # 128
+            net = concat(features_20m, axis=-1)  # 128
             net = conv3(net)  # 64
             features.append(net)
 
-        net = layers.concatenate(features, axis=-1)
+        net = concat(features, axis=-1)
         net = deconv1(net)  # 128
         net_10m = deconv2(net)  # 256
         net_20m = deconv2_20m(net)  # 128
@@ -78,6 +79,6 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
 
         # 10m-resampled stack that will be the output for inference (not used for training)
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
-        s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1)
+        s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
         return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
diff --git a/decloud/models/crga_os2_unet.py b/decloud/models/crga_os2_unet.py
index cefab53620f79149e0e06b2e9eff7cc887e9f94e..95c537b57d61e6140ec63fa7175961d6d183938e 100644
--- a/decloud/models/crga_os2_unet.py
+++ b/decloud/models/crga_os2_unet.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers, initializers
 from decloud.models.crga_os2_base import crga_os2_base
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class crga_os2_unet(crga_os2_base):
@@ -67,13 +68,13 @@ class crga_os2_unet(crga_os2_base):
                                    kernel_initializer=initializers.VarianceScaling())
 
         for input_image in input_dict:
-            net = layers.concatenate(input_dict[input_image], axis=-1)
+            net = concat(input_dict[input_image], axis=-1)
             net = conv1(net)  # 256
             features[1].append(net)
             net = conv2(net)  # 128
             if self.has_dem():
                 net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-                net = layers.concatenate([net, net_dem], axis=-1)
+                net = concat([net, net_dem], axis=-1)
             features[2].append(net)
             net = conv3(net)  # 64
             features[4].append(net)
@@ -88,7 +89,7 @@ class crga_os2_unet(crga_os2_base):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
diff --git a/decloud/models/crga_os2_unet_all_bands.py b/decloud/models/crga_os2_unet_all_bands.py
index fad9e77c6d73d9f983aad40c032dc756a2302124..7e3c0543473928a6a596f35b42d88133af2babe2 100644
--- a/decloud/models/crga_os2_unet_all_bands.py
+++ b/decloud/models/crga_os2_unet_all_bands.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.crga_os2_base_all_bands import crga_os2_base_all_bands
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class crga_os2_unet_all_bands(crga_os2_base_all_bands):
@@ -65,7 +66,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
         # The network
         features = {factor: [] for factor in [1, 2, 4, 8, 16, 32]}
         for input_image in input_dict:
-            net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1)
+            net_10m = concat(input_dict[input_image][:2], axis=-1)
             net_10m = conv1(net_10m)  # 256
             features[1].append(net_10m)
             net_10m = conv2(net_10m)  # 128
@@ -74,7 +75,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
             features_20m = [net_10m, net_20m]
             if self.has_dem():
                 features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
-            net = layers.concatenate(features_20m, axis=-1)  # 128
+            net = concat(features_20m, axis=-1)  # 128
             features[2].append(net)
             net = conv3(net)  # 64
             features[4].append(net)
@@ -89,7 +90,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
@@ -109,6 +110,6 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
 
         # 10m-resampled stack that will be the output for inference (not used for training)
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
-        s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1)
+        s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
         return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
diff --git a/decloud/models/meraner_original.py b/decloud/models/meraner_original.py
index 3aee97a342274d857e7b5dcfaad3821f5f431d97..a4302195f6a30fb510724e51a5c543397f80a236 100644
--- a/decloud/models/meraner_original.py
+++ b/decloud/models/meraner_original.py
@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
 """Implementation of the Meraner et al. original network"""
 from tensorflow.keras import layers
 from decloud.models.model import Model
+from tensorflow import concat
 
 
 class meraner_original(Model):
@@ -50,7 +51,7 @@ class meraner_original(Model):
 
         # The network
         conv1 = layers.Conv2D(resblocks_dim, 3, 1, activation='relu', name="conv1_relu", padding="same")
-        net = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
+        net = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
         net = conv1(net)
         for i in range(n_resblocks):
             net = _resblock(net, i)
diff --git a/decloud/models/meraner_unet.py b/decloud/models/meraner_unet.py
index c2b6141f5cb5995e9386a08c56f34faf9c966395..1696869e21cf08ec93d039f91617cf7c295f1fd9 100644
--- a/decloud/models/meraner_unet.py
+++ b/decloud/models/meraner_unet.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.model import Model
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class meraner_unet(Model):
@@ -56,13 +57,13 @@ class meraner_unet(Model):
         deconv5 = layers.Conv2DTranspose(64, 3, 2, activation='relu', name="deconv5_bn_relu", padding="same")
         conv_final = layers.Conv2D(4, 5, 1, name="s2_estim", padding="same")
 
-        net = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
+        net = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
         net = conv1(net)  # 256
         features[1].append(net)
         net = conv2(net)  # 128
         if self.has_dem():
             net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-            net = layers.concatenate([net, net_dem], axis=-1)
+            net = concat([net, net_dem], axis=-1)
         features[2].append(net)
         net = conv3(net)  # 64
         features[4].append(net)
@@ -76,7 +77,7 @@ class meraner_unet(Model):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = deconv1(net)  # 16
         net = _combine(factor=16, x=net)
diff --git a/decloud/models/meraner_unet_all_bands.py b/decloud/models/meraner_unet_all_bands.py
index cbee50a9b46fa3e5854e6b19e75ab87d48a79adf..206f16a4ff54f4a2101c4fa8cf8a20102b89f824 100644
--- a/decloud/models/meraner_unet_all_bands.py
+++ b/decloud/models/meraner_unet_all_bands.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 import decloud.preprocessing.constants as constants
 from decloud.models.model import Model
 from tensorflow.keras import layers
+from tensorflow import concat
 
 
 class meraner_unet_all_bands(Model):
@@ -58,7 +59,7 @@ class meraner_unet_all_bands(Model):
         conv_final = layers.Conv2D(4, 5, 1, name="s2_estim", padding="same")
         conv_20m_final = layers.Conv2D(6, 3, 1, name="s2_20m_estim", padding="same")
 
-        net_10m = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
+        net_10m = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
         net_10m = conv1(net_10m)  # 256
         features[1].append(net_10m)
         net_10m = conv2(net_10m)  # 128
@@ -67,7 +68,7 @@ class meraner_unet_all_bands(Model):
         features_20m = [net_10m, net_20m]
         if self.has_dem():
             features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
-        net = layers.concatenate(features_20m, axis=-1)  # 128
+        net = concat(features_20m, axis=-1)  # 128
 
         features[2].append(net)
         net = conv3(net)  # 64
@@ -83,7 +84,7 @@ class meraner_unet_all_bands(Model):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
@@ -102,6 +103,6 @@ class meraner_unet_all_bands(Model):
 
         # 10m-resampled stack that will be the output for inference (not used for training)
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
-        s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1)
+        s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
         return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
diff --git a/decloud/models/monthly_synthesis_6_s2_images.py b/decloud/models/monthly_synthesis_6_s2_images.py
index 5ed1a8ad65b616f628f184e1dc576433a9d49697..fc04ed33d1807c0ccd19d1b7c9fa9bde1e1467af 100644
--- a/decloud/models/monthly_synthesis_6_s2_images.py
+++ b/decloud/models/monthly_synthesis_6_s2_images.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.model import Model
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class monthly_synthesis_6_s2_images(Model):
@@ -55,7 +56,7 @@ class monthly_synthesis_6_s2_images(Model):
             features[1].append(net)
             net = conv2(net)  # 128
             if self.has_dem():
-                net = layers.concatenate([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
+                net = concat([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
             features[2].append(net)
             net = conv3(net)  # 64
             features[4].append(net)
@@ -70,7 +71,7 @@ class monthly_synthesis_6_s2_images(Model):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
diff --git a/decloud/models/monthly_synthesis_6_s2_images_david.py b/decloud/models/monthly_synthesis_6_s2_images_david.py
index 3f10ca65b00b91b7c47626653920f06316ca283b..5d1e744186f4c8a843e3ef6da1539feaa3dce8db 100644
--- a/decloud/models/monthly_synthesis_6_s2_images_david.py
+++ b/decloud/models/monthly_synthesis_6_s2_images_david.py
@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
 """David model implementation (monthly synthesis of 6 optical images)"""
 from tensorflow.keras import layers
 from decloud.models.model import Model
+from tensorflow import concat
 
 
 class monthly_synthesis_6_s2_images_david(Model):
@@ -48,7 +49,7 @@ class monthly_synthesis_6_s2_images_david(Model):
             net = conv3(net)  # 64
             features.append(net)
 
-        net = layers.concatenate(features, axis=-1)
+        net = concat(features, axis=-1)
         net = deconv1(net)  # 128
         net = deconv2(net)  # 256
         s2_out = conv4(net)  # 256
diff --git a/decloud/models/monthly_synthesis_6_s2s1_images.py b/decloud/models/monthly_synthesis_6_s2s1_images.py
index 8b1f8fb292fa9c92d7a3c8d2a85072160fe0fe72..838a9fe338b1f3e2351f655e722c290d72191b57 100644
--- a/decloud/models/monthly_synthesis_6_s2s1_images.py
+++ b/decloud/models/monthly_synthesis_6_s2s1_images.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.model import Model
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class monthly_synthesis_6_s2s1_images(Model):
@@ -64,7 +65,7 @@ class monthly_synthesis_6_s2s1_images(Model):
                     features[1].append(net)
                     net = conv2_s2(net)
                 if self.has_dem():
-                    net = layers.concatenate([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
+                    net = concat([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
                 features[2].append(net)
                 net = conv3(net)  # 64
                 features[4].append(net)
@@ -79,7 +80,7 @@ class monthly_synthesis_6_s2s1_images(Model):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
diff --git a/decloud/models/monthly_synthesis_6_s2s1_images_david.py b/decloud/models/monthly_synthesis_6_s2s1_images_david.py
index 1c3ec0239a3b51d83462da9640c1889bef35013b..123053352edf00ca3b432ed56eb66c2ce8386e4c 100644
--- a/decloud/models/monthly_synthesis_6_s2s1_images_david.py
+++ b/decloud/models/monthly_synthesis_6_s2s1_images_david.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.model import Model
 from decloud.preprocessing import constants
+from tensorflow import concat
 
 
 class monthly_synthesis_6_s2s1_images_david(Model):
@@ -56,11 +57,11 @@ class monthly_synthesis_6_s2s1_images_david(Model):
                 net = conv2(net)  # 128
                 if self.has_dem():
                     net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-                    net = layers.concatenate([net, net_dem], axis=-1)
+                    net = concat([net, net_dem], axis=-1)
                 net = conv3(net)  # 64
                 features.append(net)
 
-        net = layers.concatenate(features, axis=-1)
+        net = concat(features, axis=-1)
         net = deconv1(net)  # 128
         net = deconv2(net)  # 256
         s2_out = conv4(net)  # 256