diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 86e699f17d3e6258df2119fffdae63d49faacb10..958c269fd8b39babadb477e1348dd68f0c5965b3 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -4,8 +4,8 @@ variables:
   GPU_IMAGE_NAME: $CI_REGISTRY_IMAGE:gpu
   DOCKER_BUILDKIT: 1
   DOCKER_DRIVER: overlay2
-  CPU_BASE_IMAGE: gitlab-registry.irstea.fr/remi.cresson/otbtf:3.3.2-cpu-dev
-  GPU_BASE_IMAGE: gitlab-registry.irstea.fr/remi.cresson/otbtf:3.3.2-gpu-dev
+  CPU_BASE_IMAGE: gitlab-registry.irstea.fr/remi.cresson/otbtf:4.1.0-cpu-dev
+  GPU_BASE_IMAGE: gitlab-registry.irstea.fr/remi.cresson/otbtf:4.1.0-gpu-dev
 
 workflow:
   rules:
@@ -57,17 +57,17 @@ Build the docker image:
 flake8:
   extends: .static_analysis_base
   script:
-   - sudo apt update && sudo apt install -y flake8 && python -m flake8 --ignore=E402 --max-line-length=120 $PWD/decloud
+   - pip install flake8 && flake8 --ignore=E402 --max-line-length=120 $PWD/decloud
 
 pylint:
   extends: .static_analysis_base
   script:
-  - sudo apt update && sudo apt install -y pylint && pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments,invalid-name,cell-var-from-loop,too-many-branches,too-many-ancestors --ignored-modules=tensorflow,git,rtree,scipy,tensorboard,libamalthee,pandas --max-line-length=120 $PWD/decloud
+   - pip install pylint && pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments,invalid-name,cell-var-from-loop,too-many-branches,too-many-ancestors --ignored-modules=tensorflow,git,rtree,scipy,tensorboard,libamalthee,pandas --max-line-length=120 $PWD/decloud
 
 codespell:
   extends: .static_analysis_base
   script:
-    - sudo pip install codespell && codespell --skip="*.png,*.template,*.pbs,*.jpg,*git/lfs*"
+   - pip install codespell && codespell --skip="*.png,*.template,*.pbs,*.jpg,*git/lfs*"
 
 .applications_test_base:
   image: $TEST_IMAGE_NAME
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 2af0b57f382da9ece34be4f433414a341e7f334b..685c42c13c1bef5d1ee42f2fe440508ad100c5da 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,5 +1,5 @@
 Contributors:
-- Rémi Cresson (remi.cresson@inrae.fr)
-- Nicolas Narçon (nicolas.narcon@inrae.fr)
-- Benjamin Commandre (benjamin.commandre@inrae.fr)
-- Raffaele Gaetano (raffaele.gaetano@cirad.fr)
+- Rémi Cresson (remi.cresson at inrae.fr)
+- Nicolas Narçon
+- Benjamin Commandre
+- Raffaele Gaetano
diff --git a/Dockerfile b/Dockerfile
index 43814a853dac96e115b205194b04b404e355a64c..31f4979df227f27eafad185c6527a646c5fdc2be 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,9 +1,9 @@
 # Decloud dockerfile
 # To build the docker image for cpu, do the following:
 #
-# docker build --build-arg "BASE_IMAGE=mdl4eo/otbtf:3.3.2-cpu-dev" .
+# docker build --build-arg "BASE_IMAGE=mdl4eo/otbtf:4.1.0-cpu-dev" .
 #
-ARG BASE_IMAGE=mdl4eo/otbtf:3.3.2-gpu-dev
+ARG BASE_IMAGE=mdl4eo/otbtf:4.1.0-gpu-dev
 FROM $BASE_IMAGE
 LABEL description="Decloud docker image"
 LABEL maintainer="Remi Cresson [at] inrae [dot] fr"
diff --git a/LICENSE b/LICENSE
index 0e166bdc96d6f375116f2e686bf42f38b8b94da0..37862d75bd5513d69ac53834ce7b8df39b23de6f 100644
--- a/LICENSE
+++ b/LICENSE
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.
 
-   Copyright 2020 INRAE
+   Copyright 2020-2023 INRAE
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
diff --git a/README.md b/README.md
index f249fa83114b0c0e4104db8b5294e176b90d2464..f2be459f2fcb408dafe9539f9d8881fc8d5ebfb5 100644
--- a/README.md
+++ b/README.md
@@ -12,10 +12,12 @@ Representative illustrations:
 ## Cite
 
 ```
-@article{cresson2022clouds,
+@inproceedings{cresson2022comparison,
   title={Comparison of convolutional neural networks for cloudy optical images reconstruction from single or multitemporal joint SAR and optical images},
-  author={Cresson, R., Narcon, N., Gaetano, R., Dupuis A., Tanguy, Y., May, S., Commandre, B.},
-  journal={arXiv preprint arXiv:2204.00424},
+  author={Cresson, R{\'e}mi and Nar{\c{c}}on, N and Gaetano, Raffaele and Dupuis, Aurore and Tanguy, Yannick and May, St{\'e}phane and Commandr{\'e}, Benjamin},
+  booktitle={XXIV ISPRS Congress (2022 edition)},
+  volume={43},
+  pages={1317--1326},
   year={2022}
 }
 ```
diff --git a/decloud/models/create_tfrecords.py b/decloud/models/create_tfrecords.py
index d9993007c7366aceff51d0a7c937cfcc7ba801f2..21b8ad547a8779f88580e891b3cd027fa7e8d706 100644
--- a/decloud/models/create_tfrecords.py
+++ b/decloud/models/create_tfrecords.py
@@ -22,7 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 DEALINGS IN THE SOFTWARE.
 """
 """Create some TFRecords from a decloud.dataset"""
-import os
 import argparse
 import sys
 import logging
diff --git a/decloud/models/crga_os1_unet.py b/decloud/models/crga_os1_unet.py
index 4302cb263c022a085e594e7fad87ddb8fb505d2b..62e1ac182b295c0e7f7bd1ed1bec149aed8cda54 100644
--- a/decloud/models/crga_os1_unet.py
+++ b/decloud/models/crga_os1_unet.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 import decloud.preprocessing.constants as constants
 from decloud.models.crga_os1_base import crga_os1_base
+from tensorflow import concat
 
 
 class crga_os1_unet(crga_os1_base):
@@ -57,14 +58,14 @@ class crga_os1_unet(crga_os1_base):
             if input_image == "current":
                 net = conv1_s1(input_dict[input_image])  # 256
             else:
-                net = layers.concatenate(input_dict[input_image], axis=-1)
+                net = concat(input_dict[input_image], axis=-1)
                 net = conv1_s1s2(net)  # 256
 
             features[1].append(net)
             net = conv2(net)  # 128
             if self.has_dem():
                 net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-                net = layers.concatenate([net, net_dem], axis=-1)
+                net = concat([net, net_dem], axis=-1)
             features[2].append(net)
             net = conv3(net)  # 64
             features[4].append(net)
@@ -79,7 +80,7 @@ class crga_os1_unet(crga_os1_base):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
diff --git a/decloud/models/crga_os1_unet_all_bands.py b/decloud/models/crga_os1_unet_all_bands.py
index 252dd1a50834a569c8161a76999984ec718417c1..ac044383e016e9f94dd5ee973801d8069c74e367 100644
--- a/decloud/models/crga_os1_unet_all_bands.py
+++ b/decloud/models/crga_os1_unet_all_bands.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 import decloud.preprocessing.constants as constants
 from decloud.models.crga_os1_base_all_bands import crga_os1_base_all_bands
+from tensorflow import concat
 
 
 class crga_os1_unet_all_bands(crga_os1_base_all_bands):
@@ -69,7 +70,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
                 features[1].append(net_10m)
                 net = conv2(net_10m)  # 128
             else:  # for post & ante, the is s1, s2 and s2_20m
-                net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1)
+                net_10m = concat(input_dict[input_image][:2], axis=-1)
                 net_10m = conv1_s1s2(net_10m)  # 256
                 features[1].append(net_10m)
                 net_10m = conv2(net_10m)  # 128
@@ -77,7 +78,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
                 features_20m = [net_10m, net_20m]
                 if self.has_dem():
                     features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
-                net = layers.concatenate(features_20m, axis=-1)
+                net = concat(features_20m, axis=-1)
                 net = conv2_20m(net)  # 128
 
             features[2].append(net)
@@ -94,7 +95,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
@@ -114,6 +115,6 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
 
         # 10m-resampled stack that will be the output for inference (not used for training)
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
-        s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1)
+        s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
         return {"s2_t": s2_out, "s2_20m_t": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
diff --git a/decloud/models/crga_os2_david.py b/decloud/models/crga_os2_david.py
index 62efac978d57b77cdbc81763023318602a48399d..837d6121e23a7dd88d27a8aa40573f068efc22e1 100644
--- a/decloud/models/crga_os2_david.py
+++ b/decloud/models/crga_os2_david.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.crga_os2_base import crga_os2_base
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class crga_os2_david(crga_os2_base):
@@ -50,16 +51,16 @@ class crga_os2_david(crga_os2_base):
         deconv2 = layers.Conv2DTranspose(64, 3, 2, activation='relu', name="deconv2_bn_relu", padding="same")
         conv4 = layers.Conv2D(4, 5, 1, activation='relu', name="s2_estim", padding="same")
         for input_image in input_dict:
-            net = layers.concatenate(input_dict[input_image], axis=-1)
+            net = concat(input_dict[input_image], axis=-1)
             net = conv1(net)  # 256
             net = conv2(net)  # 128
             if self.has_dem():
                 net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-                net = layers.concatenate([net, net_dem], axis=-1)
+                net = concat([net, net_dem], axis=-1)
             net = conv3(net)  # 64
             features.append(net)
 
-        net = layers.concatenate(features, axis=-1)
+        net = concat(features, axis=-1)
         net = deconv1(net)    # 128
         net = deconv2(net)    # 256
         s2_out = conv4(net)   # 256
diff --git a/decloud/models/crga_os2_david_all_bands.py b/decloud/models/crga_os2_david_all_bands.py
index 0b1f34b13ff4f0b597b133acf37267fe68e8bbe1..247fa1a5daa41f926fb338128d15eaac535f68fb 100644
--- a/decloud/models/crga_os2_david_all_bands.py
+++ b/decloud/models/crga_os2_david_all_bands.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.crga_os2_base_all_bands import crga_os2_base_all_bands
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class crga_os2_david_all_bands(crga_os2_base_all_bands):
@@ -56,7 +57,7 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
         conv4 = layers.Conv2D(4, 5, 1, activation='relu', name="s2_estim", padding="same")
         conv4_20m = layers.Conv2D(6, 3, 1, activation='relu', name="s2_20m_estim", padding="same")
         for input_image in input_dict:
-            net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1)
+            net_10m = concat(input_dict[input_image][:2], axis=-1)
             net_10m = conv1(net_10m)  # 256
             net_10m = conv2(net_10m)  # 128
             net_20m = conv1_20m(input_dict[input_image][2])  # 128
@@ -64,11 +65,11 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
             features_20m = [net_10m, net_20m]
             if self.has_dem():
                 features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
-            net = layers.concatenate(features_20m, axis=-1)  # 128
+            net = concat(features_20m, axis=-1)  # 128
             net = conv3(net)  # 64
             features.append(net)
 
-        net = layers.concatenate(features, axis=-1)
+        net = concat(features, axis=-1)
         net = deconv1(net)  # 128
         net_10m = deconv2(net)  # 256
         net_20m = deconv2_20m(net)  # 128
@@ -78,6 +79,6 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
 
         # 10m-resampled stack that will be the output for inference (not used for training)
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
-        s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1)
+        s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
         return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
diff --git a/decloud/models/crga_os2_unet.py b/decloud/models/crga_os2_unet.py
index cefab53620f79149e0e06b2e9eff7cc887e9f94e..95c537b57d61e6140ec63fa7175961d6d183938e 100644
--- a/decloud/models/crga_os2_unet.py
+++ b/decloud/models/crga_os2_unet.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers, initializers
 from decloud.models.crga_os2_base import crga_os2_base
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class crga_os2_unet(crga_os2_base):
@@ -67,13 +68,13 @@ class crga_os2_unet(crga_os2_base):
                                    kernel_initializer=initializers.VarianceScaling())
 
         for input_image in input_dict:
-            net = layers.concatenate(input_dict[input_image], axis=-1)
+            net = concat(input_dict[input_image], axis=-1)
             net = conv1(net)  # 256
             features[1].append(net)
             net = conv2(net)  # 128
             if self.has_dem():
                 net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-                net = layers.concatenate([net, net_dem], axis=-1)
+                net = concat([net, net_dem], axis=-1)
             features[2].append(net)
             net = conv3(net)  # 64
             features[4].append(net)
@@ -88,7 +89,7 @@ class crga_os2_unet(crga_os2_base):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
diff --git a/decloud/models/crga_os2_unet_all_bands.py b/decloud/models/crga_os2_unet_all_bands.py
index fad9e77c6d73d9f983aad40c032dc756a2302124..7e3c0543473928a6a596f35b42d88133af2babe2 100644
--- a/decloud/models/crga_os2_unet_all_bands.py
+++ b/decloud/models/crga_os2_unet_all_bands.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.crga_os2_base_all_bands import crga_os2_base_all_bands
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class crga_os2_unet_all_bands(crga_os2_base_all_bands):
@@ -65,7 +66,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
         # The network
         features = {factor: [] for factor in [1, 2, 4, 8, 16, 32]}
         for input_image in input_dict:
-            net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1)
+            net_10m = concat(input_dict[input_image][:2], axis=-1)
             net_10m = conv1(net_10m)  # 256
             features[1].append(net_10m)
             net_10m = conv2(net_10m)  # 128
@@ -74,7 +75,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
             features_20m = [net_10m, net_20m]
             if self.has_dem():
                 features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
-            net = layers.concatenate(features_20m, axis=-1)  # 128
+            net = concat(features_20m, axis=-1)  # 128
             features[2].append(net)
             net = conv3(net)  # 64
             features[4].append(net)
@@ -89,7 +90,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
@@ -109,6 +110,6 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
 
         # 10m-resampled stack that will be the output for inference (not used for training)
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
-        s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1)
+        s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
         return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
diff --git a/decloud/models/meraner_original.py b/decloud/models/meraner_original.py
index 3aee97a342274d857e7b5dcfaad3821f5f431d97..a4302195f6a30fb510724e51a5c543397f80a236 100644
--- a/decloud/models/meraner_original.py
+++ b/decloud/models/meraner_original.py
@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
 """Implementation of the Meraner et al. original network"""
 from tensorflow.keras import layers
 from decloud.models.model import Model
+from tensorflow import concat
 
 
 class meraner_original(Model):
@@ -50,7 +51,7 @@ class meraner_original(Model):
 
         # The network
         conv1 = layers.Conv2D(resblocks_dim, 3, 1, activation='relu', name="conv1_relu", padding="same")
-        net = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
+        net = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
         net = conv1(net)
         for i in range(n_resblocks):
             net = _resblock(net, i)
diff --git a/decloud/models/meraner_unet.py b/decloud/models/meraner_unet.py
index c2b6141f5cb5995e9386a08c56f34faf9c966395..1696869e21cf08ec93d039f91617cf7c295f1fd9 100644
--- a/decloud/models/meraner_unet.py
+++ b/decloud/models/meraner_unet.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.model import Model
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class meraner_unet(Model):
@@ -56,13 +57,13 @@ class meraner_unet(Model):
         deconv5 = layers.Conv2DTranspose(64, 3, 2, activation='relu', name="deconv5_bn_relu", padding="same")
         conv_final = layers.Conv2D(4, 5, 1, name="s2_estim", padding="same")
 
-        net = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
+        net = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
         net = conv1(net)  # 256
         features[1].append(net)
         net = conv2(net)  # 128
         if self.has_dem():
             net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-            net = layers.concatenate([net, net_dem], axis=-1)
+            net = concat([net, net_dem], axis=-1)
         features[2].append(net)
         net = conv3(net)  # 64
         features[4].append(net)
@@ -76,7 +77,7 @@ class meraner_unet(Model):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = deconv1(net)  # 16
         net = _combine(factor=16, x=net)
diff --git a/decloud/models/meraner_unet_all_bands.py b/decloud/models/meraner_unet_all_bands.py
index cbee50a9b46fa3e5854e6b19e75ab87d48a79adf..76e2389ea0033188bf7d7733af0b80a29b79b515 100644
--- a/decloud/models/meraner_unet_all_bands.py
+++ b/decloud/models/meraner_unet_all_bands.py
@@ -24,18 +24,32 @@ DEALINGS IN THE SOFTWARE.
 import decloud.preprocessing.constants as constants
 from decloud.models.model import Model
 from tensorflow.keras import layers
+from tensorflow import concat
 
 
 class meraner_unet_all_bands(Model):
     """
     Implementation of a variant of the Meraner et al. network (all bands)
     """
-    def __init__(self, dataset_shapes,
-                 dataset_input_keys=["s1_t", "s2_t", "s2_20m_t",
-                                     constants.DEM_KEY],
-                 model_output_keys=["s2_target"]):
-        super().__init__(dataset_input_keys=dataset_input_keys, model_output_keys=model_output_keys,
-                         dataset_shapes=dataset_shapes)
+    def __init__(
+            self,
+            dataset_shapes,
+            dataset_input_keys=[
+                "s1_t",
+                "s2_t",
+                "s2_20m_t",
+                constants.DEM_KEY
+            ],
+            model_output_keys=[
+                "s2_target",
+                "s2_20m_target"
+            ]
+    ):
+        super().__init__(
+            dataset_input_keys=dataset_input_keys,
+            model_output_keys=model_output_keys,
+            dataset_shapes=dataset_shapes
+        )
 
     def get_outputs(self, normalized_inputs):
         # The network
@@ -58,7 +72,7 @@ class meraner_unet_all_bands(Model):
         conv_final = layers.Conv2D(4, 5, 1, name="s2_estim", padding="same")
         conv_20m_final = layers.Conv2D(6, 3, 1, name="s2_20m_estim", padding="same")
 
-        net_10m = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
+        net_10m = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
         net_10m = conv1(net_10m)  # 256
         features[1].append(net_10m)
         net_10m = conv2(net_10m)  # 128
@@ -67,7 +81,7 @@ class meraner_unet_all_bands(Model):
         features_20m = [net_10m, net_20m]
         if self.has_dem():
             features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
-        net = layers.concatenate(features_20m, axis=-1)  # 128
+        net = concat(features_20m, axis=-1)  # 128
 
         features[2].append(net)
         net = conv3(net)  # 64
@@ -83,7 +97,7 @@ class meraner_unet_all_bands(Model):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
@@ -102,6 +116,10 @@ class meraner_unet_all_bands(Model):
 
         # 10m-resampled stack that will be the output for inference (not used for training)
         s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
-        s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1)
+        s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
 
-        return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
+        return {
+            "s2_target": s2_out,
+            "s2_20m_target": s2_20m_out,
+            "s2_all_bands_estim": s2_all_bands
+        }
diff --git a/decloud/models/monthly_synthesis_6_s2_images.py b/decloud/models/monthly_synthesis_6_s2_images.py
index 5ed1a8ad65b616f628f184e1dc576433a9d49697..fc04ed33d1807c0ccd19d1b7c9fa9bde1e1467af 100644
--- a/decloud/models/monthly_synthesis_6_s2_images.py
+++ b/decloud/models/monthly_synthesis_6_s2_images.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.model import Model
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class monthly_synthesis_6_s2_images(Model):
@@ -55,7 +56,7 @@ class monthly_synthesis_6_s2_images(Model):
             features[1].append(net)
             net = conv2(net)  # 128
             if self.has_dem():
-                net = layers.concatenate([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
+                net = concat([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
             features[2].append(net)
             net = conv3(net)  # 64
             features[4].append(net)
@@ -70,7 +71,7 @@ class monthly_synthesis_6_s2_images(Model):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
diff --git a/decloud/models/monthly_synthesis_6_s2_images_david.py b/decloud/models/monthly_synthesis_6_s2_images_david.py
index 3f10ca65b00b91b7c47626653920f06316ca283b..5d1e744186f4c8a843e3ef6da1539feaa3dce8db 100644
--- a/decloud/models/monthly_synthesis_6_s2_images_david.py
+++ b/decloud/models/monthly_synthesis_6_s2_images_david.py
@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
 """David model implementation (monthly synthesis of 6 optical images)"""
 from tensorflow.keras import layers
 from decloud.models.model import Model
+from tensorflow import concat
 
 
 class monthly_synthesis_6_s2_images_david(Model):
@@ -48,7 +49,7 @@ class monthly_synthesis_6_s2_images_david(Model):
             net = conv3(net)  # 64
             features.append(net)
 
-        net = layers.concatenate(features, axis=-1)
+        net = concat(features, axis=-1)
         net = deconv1(net)  # 128
         net = deconv2(net)  # 256
         s2_out = conv4(net)  # 256
diff --git a/decloud/models/monthly_synthesis_6_s2s1_images.py b/decloud/models/monthly_synthesis_6_s2s1_images.py
index 8b1f8fb292fa9c92d7a3c8d2a85072160fe0fe72..838a9fe338b1f3e2351f655e722c290d72191b57 100644
--- a/decloud/models/monthly_synthesis_6_s2s1_images.py
+++ b/decloud/models/monthly_synthesis_6_s2s1_images.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.model import Model
 import decloud.preprocessing.constants as constants
+from tensorflow import concat
 
 
 class monthly_synthesis_6_s2s1_images(Model):
@@ -64,7 +65,7 @@ class monthly_synthesis_6_s2s1_images(Model):
                     features[1].append(net)
                     net = conv2_s2(net)
                 if self.has_dem():
-                    net = layers.concatenate([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
+                    net = concat([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
                 features[2].append(net)
                 net = conv3(net)  # 64
                 features[4].append(net)
@@ -79,7 +80,7 @@ class monthly_synthesis_6_s2s1_images(Model):
         def _combine(factor, x=None):
             if x is not None:
                 features[factor].append(x)
-            return layers.concatenate(features[factor], axis=-1)
+            return concat(features[factor], axis=-1)
 
         net = _combine(factor=32)
         net = deconv1(net)  # 16
diff --git a/decloud/models/monthly_synthesis_6_s2s1_images_david.py b/decloud/models/monthly_synthesis_6_s2s1_images_david.py
index 1c3ec0239a3b51d83462da9640c1889bef35013b..123053352edf00ca3b432ed56eb66c2ce8386e4c 100644
--- a/decloud/models/monthly_synthesis_6_s2s1_images_david.py
+++ b/decloud/models/monthly_synthesis_6_s2s1_images_david.py
@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
 from tensorflow.keras import layers
 from decloud.models.model import Model
 from decloud.preprocessing import constants
+from tensorflow import concat
 
 
 class monthly_synthesis_6_s2s1_images_david(Model):
@@ -56,11 +57,11 @@ class monthly_synthesis_6_s2s1_images_david(Model):
                 net = conv2(net)  # 128
                 if self.has_dem():
                     net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
-                    net = layers.concatenate([net, net_dem], axis=-1)
+                    net = concat([net, net_dem], axis=-1)
                 net = conv3(net)  # 64
                 features.append(net)
 
-        net = layers.concatenate(features, axis=-1)
+        net = concat(features, axis=-1)
         net = deconv1(net)  # 128
         net = deconv2(net)  # 256
         s2_out = conv4(net)  # 256
diff --git a/decloud/models/train_from_tfrecords.py b/decloud/models/train_from_tfrecords.py
index eb58d4e449e841342867a547296a92d4cd4e11dc..75ef766905e5bb695107aab6a2b263d55bc1e662 100644
--- a/decloud/models/train_from_tfrecords.py
+++ b/decloud/models/train_from_tfrecords.py
@@ -146,117 +146,122 @@ def main(args):
     expe_name += "_e{}".format(params.epochs)
     expe_name += suffix
 
-    if True:  # TODO: detete, just used for review
-        # Date tag
-        date_tag = time.strftime("%d-%m-%y-%H%M%S")
-
-        # adding the info to the SavedModel path
-        out_savedmodel = None if params.out_savedmodel is None else \
-            os.path.join(params.out_savedmodel, expe_name + date_tag)
-
-        # Scaling batch size and learning rate accordingly to number of workers
-        batch_size_train = params.batch_size_train * n_workers
-        batch_size_valid = params.batch_size_valid * n_workers
-        learning_rate = params.learning_rate * n_workers
-
-        logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
-                     learning_rate, batch_size_train, n_workers)
-
-        # Datasets
-        tfrecord_train = TFRecords(params.training_record) if params.training_record else None
-        tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
-
-        # Model instantiation
-        model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
-
-        # TF.dataset-s instantiation
-        tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
-                                          target_keys=model.model_output_keys,
-                                          n_workers=n_workers,
-                                          shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
-        tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
-                                     target_keys=model.model_output_keys,
-                                     n_workers=n_workers) for tfrecord in tfrecord_valid_array]
-
-        with strategy.scope():
-            # Creating the Keras network corresponding to the model
-            model.create_network()
-
-            # Metrics
-            metrics_list = [metrics.MeanSquaredError(), metrics.PSNR()]
-            if params.all_metrics:
-                metrics_list += [metrics.StructuralSimilarity(), metrics.SpectralAngle()]  # A bit slow to compute
-
-            # Creating the model or loading it from checkpoints
-            logging.info("Loading model \"%s\"", params.model)
-            model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
-                          loss=model.get_loss(),
-                          metrics={out_key: metrics_list for out_key in model.model_output_keys})
-            model.summary(strategy)
-
-            if params.plot_model:
-                model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
-
-            callbacks = []
-            # Define the checkpoint callback
-            if params.ckpt_dir:
-                if params.strategy == 'singlecpu':
-                    logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
-                else:
-                    backup_dir = os.path.join(params.ckpt_dir, params.model)
-                    # Backup (deleted once the model is trained the specified number of epochs)
-                    callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
-                    # Persistent save (still here after the model is trained)
-                    callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
-
-            # Define the Keras TensorBoard callback.
-            logdir = None
-            if params.logdir:
-                logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
-                tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
-                                                                   profile_batch=params.profiling)
-                callbacks.append(tensorboard_callback)
-
-                # Define the previews callback
-                if params.previews:
-                    # We run the preview on an arbitrary sample of the validation dataset
-                    sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
-                    previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
-                                                         target_keys=model.model_output_keys)
-                    callbacks.append(previews_callback)
-
-            # Validation on multiple datasets
-            if tf_ds_valid:
-                additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
-                callbacks.append(additional_validation_callback)
-
-            # Save best checkpoint only
-            if params.save_best:
-                callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
-                                                                 monitor=params.save_best_ref, mode='min'))
-
-            # Early stopping if the training stops improving
-            if params.early_stopping:
-                callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
-                                                               patience=10, mode='min'))
-
-            # Training
-            model.fit(tf_ds_train,
-                      epochs=params.epochs,
-                      validation_data=tf_ds_valid[0] if tf_ds_valid else None,
-                      callbacks=callbacks,
-                      verbose=1 if params.verbose else 2)
-
-            # Multiworker training tries to save the model multiple times and this can create corrupted models
-            # Thus we save the model at the final path only for the 'chief' worker
-            if params.strategy != 'singlecpu':
-                if not _is_chief(strategy):
-                    out_savedmodel = None
-
-            # Export SavedModel
-            if out_savedmodel and not params.save_best:
-                logging.info("Saving SavedModel in %s", out_savedmodel)
-                model.save(out_savedmodel)
+    # Date tag
+    date_tag = time.strftime("%d-%m-%y-%H%M%S")
+
+    # adding the info to the SavedModel path
+    out_savedmodel = None if params.out_savedmodel is None else \
+        os.path.join(params.out_savedmodel, expe_name + date_tag)
+
+    # Scaling batch size and learning rate according to the number of workers
+    batch_size_train = params.batch_size_train * n_workers
+    batch_size_valid = params.batch_size_valid * n_workers
+    learning_rate = params.learning_rate * n_workers
+
+    logging.info("Learning rate was scaled to %s, effective batch size is %s (%s workers)",
+                 learning_rate, batch_size_train, n_workers)
+
+    # Datasets
+    tfrecord_train = TFRecords(params.training_record) if params.training_record else None
+    tfrecord_valid_array = [TFRecords(rep) for rep in params.valid_records]
+
+    # Model instantiation
+    model = ModelFactory.get_model(params.model, dataset_shapes=tfrecord_train.output_shape)
+
+    # TF.dataset-s instantiation
+    tf_ds_train = tfrecord_train.read(batch_size=batch_size_train,
+                                      target_keys=model.model_output_keys,
+                                      n_workers=n_workers,
+                                      shuffle_buffer_size=params.shuffle_buffer_size) if tfrecord_train else None
+    tf_ds_valid = [tfrecord.read(batch_size=batch_size_valid,
+                                 target_keys=model.model_output_keys,
+                                 n_workers=n_workers) for tfrecord in tfrecord_valid_array]
+
+    with strategy.scope():
+        # Creating the Keras network corresponding to the model
+        model.create_network()
+
+        # Metrics
+        metrics_list = [metrics.MeanSquaredError, metrics.PSNR]
+        if params.all_metrics:
+            metrics_list += [metrics.StructuralSimilarity, metrics.SpectralAngle]  # A bit slow to compute
+
+        # Creating the model or loading it from checkpoints
+        logging.info("Loading model \"%s\"", params.model)
+        model.compile(
+            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
+            loss=model.get_loss(),
+            metrics={
+                # one fresh list of metric instances per output head
+                out_key: [metric() for metric in metrics_list]
+                for out_key in model.model_output_keys
+            }
+        )
+        model.summary(strategy)
+
+        if params.plot_model:
+            model.plot('/tmp/model_architecture_{}.png'.format(model.__class__.__name__), strategy)
+
+        callbacks = []
+        # Define the checkpoint callback
+        if params.ckpt_dir:
+            if params.strategy == 'singlecpu':
+                logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
+            else:
+                backup_dir = os.path.join(params.ckpt_dir, params.model)
+                # Backup (deleted once the model is trained the specified number of epochs)
+                callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
+                # Persistent save (still here after the model is trained)
+                callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
+
+        # Define the Keras TensorBoard callback.
+        logdir = None
+        if params.logdir:
+            logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
+            tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
+                                                               profile_batch=params.profiling)
+            callbacks.append(tensorboard_callback)
+
+            # Define the previews callback
+            if params.previews:
+                # We run the preview on an arbitrary sample of the validation dataset
+                sample = tfrecord_valid_array[0].read_one_sample(target_keys=model.model_output_keys)
+                previews_callback = PreviewsCallback(sample, logdir, input_keys=model.dataset_input_keys,
+                                                     target_keys=model.model_output_keys)
+                callbacks.append(previews_callback)
+
+        # Validation on multiple datasets
+        if tf_ds_valid:
+            additional_validation_callback = AdditionalValidationSets(tf_ds_valid[1:], logdir)
+            callbacks.append(additional_validation_callback)
+
+        # Save best checkpoint only
+        if params.save_best:
+            callbacks.append(keras.callbacks.ModelCheckpoint(params.out_savedmodel, save_best_only=True,
+                                                             monitor=params.save_best_ref, mode='min'))
+
+        # Early stopping if the training stops improving
+        if params.early_stopping:
+            callbacks.append(keras.callbacks.EarlyStopping(monitor=params.save_best_ref, min_delta=0.0001,
+                                                           patience=10, mode='min'))
+
+        # Training
+        model.fit(tf_ds_train,
+                  epochs=params.epochs,
+                  validation_data=tf_ds_valid[0] if tf_ds_valid else None,
+                  callbacks=callbacks,
+                  verbose=1 if params.verbose else 2)
+
+        # Multiworker training tries to save the model multiple times and this can create corrupted models
+        # Thus we save the model at the final path only for the 'chief' worker
+        if params.strategy != 'singlecpu':
+            if not _is_chief(strategy):
+                out_savedmodel = None
+
+        # Export SavedModel
+        if out_savedmodel and not params.save_best:
+            logging.info("Saving SavedModel in %s", out_savedmodel)
+            model.save(out_savedmodel)
 
 
 if __name__ == "__main__":
diff --git a/setup.py b/setup.py
index ec9843b927c4ad69034de0b3a72905f199fd3038..83446ec85865532a4da1a06579378ed5c2b4fad1 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Copyright (c) 2020-2022 INRAE
+Copyright (c) 2020-2023 INRAE
 
 Permission is hereby granted, free of charge, to any person obtaining a
 copy of this software and associated documentation files (the "Software"),
@@ -28,7 +28,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 
 setuptools.setup(
     name="decloud",
-    version="1.3",
+    version="1.4",
     author="Remi Cresson, Nicolas Narçon, Benjamin Commandre",
     author_email="remi.cresson@inrae.fr",
     description="Deep learning based reconstruction of optical time series using SAR imagery",