diff --git a/Moringa2DL.py b/Moringa2DL.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4738ff1af2cdab51ff2c1c4680f93383bcde9a2
--- /dev/null
+++ b/Moringa2DL.py
@@ -0,0 +1,297 @@
+import os
+from osgeo import gdal
+import subprocess
+import glob
+import numpy as np
+import mtdUtils
+import xml.etree.ElementTree as ET
+#import otbApplication as otb
+
+def generateAll(feat_fld,reference,seg_src,cfield,ofld='.',dstype='train'):
+    image_reference = sorted(glob.glob(os.path.join(feat_fld,'*','*_FEAT.tif')))[0]
+    # gt_rnn_feat = None
+    # gt_rnn_label = np.zeros(0)
+    # gt_cnn_feat = None
+    # gt_cnn_label = np.zeros(0)
+
+    # Subprocess version
+    stat_file = ofld + os.sep + os.path.splitext(os.path.basename(reference))[0] + '_polystat_'+cfield+'.xml'
+    cmd = ['otbcli_PolygonClassStatistics', '-in', image_reference, '-vec', reference, '-field', cfield, '-out',
+           stat_file]
+    print(cmd)
+    subprocess.call(cmd)
+
+    samples_file = ofld + os.sep + os.path.splitext(os.path.basename(reference))[0] + '_sample_loc_'+cfield+'.shp'
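+    # select sample locations with OTB's 'total' strategy (4400 sample pixels overall)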
+    cmd = ['otbcli_SampleSelection', '-in', image_reference, '-vec', reference, '-field', cfield, '-out',
+           samples_file, '-instats', stat_file, '-strategy', 'total', '-strategy.total.v', '4400']
+    subprocess.call(cmd)
+
+    gt_rnn_feat, gt_rnn_label = generateAnnotatedDataset(feat_fld, samples_file, cfield,sname=os.path.splitext(os.path.basename(samples_file))[0]+'_rnn_'+cfield,ofld=ofld, do_norm=False)
+    # if gt_rnn_feat is None : 
+    #     gt_rnn_feat = np.zeros((0,rnn_feat.shape[1]))
+    # np.concatenate((gt_rnn_feat,rnn_feat))
+    # np.concatenate((gt_rnn_label,rnn_label))
+    gt_cnn_feat, gt_cnn_label = CNNAnnotedDataset(seg_src, samples_file, cfield, 25, sname=os.path.splitext(os.path.basename(samples_file))[0]+'_cnn_'+cfield, ofld=ofld)
+
+    np.save(os.path.join(ofld,dstype+'_'+cfield+'_rnn.npy'),gt_rnn_feat)
+    np.save(os.path.join(ofld,dstype+'_'+cfield+'_cnn.npy'),gt_cnn_feat)
+    np.save(os.path.join(ofld,dstype+'_'+cfield+'_label.npy'),gt_cnn_label)
+    return [os.path.join(ofld,dstype+'_'+cfield+'_rnn.npy'),os.path.join(ofld,dstype+'_'+cfield+'_cnn.npy'),os.path.join(ofld,dstype+'_'+cfield+'_label.npy')]
+
+def generateAllForClassify(feat_fld,shp,seg_src,feats,cfield='Segment_ID',ofld='.'):
+    import geopandas as gp
+    basename = os.path.splitext(os.path.basename(shp))[0]
+    print(os.path.join(ofld,basename+'_'+cfield+'_rnn.npy'))
+    if not (os.path.exists(os.path.join(ofld,basename+'_'+cfield+'_rnn.npy')) and os.path.exists(os.path.join(ofld,basename+'_'+cfield+'_cnn.npy')) and os.path.exists(os.path.join(ofld,basename+'_'+cfield+'_id.npy'))):
+        image_reference = seg_src
+
+        stat_file = ofld + os.sep + os.path.splitext(os.path.basename(shp))[0] + '_polystat_'+cfield+'.xml'
+        cmd = ['otbcli_PolygonClassStatistics', '-in', image_reference, '-vec', shp, '-field', cfield, '-out',
+               stat_file]
+        subprocess.call(cmd)
+
+        samples_file = ofld + os.sep + 'sample_loc_' + os.path.splitext(os.path.basename(shp))[0] + cfield+'.shp'
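+        # one random sample location per class value; since cfield is the Segment_ID, this means one pixel per segment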
+        cmd = ['otbcli_SampleSelection', '-in', image_reference, '-vec', shp, '-field', cfield, '-out',
+               samples_file, '-instats', stat_file, '-strategy', 'constant','-strategy.constant.nb','1', '-sampler' ,'random']
+        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
+        mtdUtils.cleanEmptySegment(shp,proc.stdout)
+
+        samples = gp.read_file(samples_file)
+        samples_feats = samples.filter(feats, axis=1)
+        rnn_feat = samples_feats.to_numpy()
+        join_id = samples.filter([cfield],axis=1)
+        id_geom = join_id.to_numpy()
+
+        cnn_feat, cnn_label = CNNAnnotedDataset(seg_src, samples_file, cfield, 25, sname=os.path.splitext(os.path.basename(samples_file))[0]+'_cnn', ofld=ofld)
+
+        np.save(os.path.join(ofld,basename+'_'+cfield+'_rnn.npy'),rnn_feat)
+        np.save(os.path.join(ofld,basename+'_'+cfield+'_cnn.npy'),cnn_feat)
+        np.save(os.path.join(ofld,basename+'_'+cfield+'_id.npy'),id_geom)
+    return [os.path.join(ofld,basename+'_'+cfield+'_rnn.npy'),os.path.join(ofld,basename+'_'+cfield+'_cnn.npy'),os.path.join(ofld,basename+'_'+cfield+'_id.npy')]
+
+def generateAnnotatedDataset(fld, sampling, cl_field, sname='AnnotatedSet', ofld='.', do_norm=True):
+    imfeat = sorted(glob.glob(os.path.join(fld,'*','*_FEAT.tif')))
+
+    cmd = ['otbcli_Rasterization', '-in', sampling, '-im', imfeat[0], '-mode', 'attribute', '-mode.attribute.field', cl_field, '-out', ofld + os.sep + sname + '_gt.tif', 'uint16']
+    subprocess.call(cmd)
+
+    ds = gdal.Open(ofld + os.sep + sname + '_gt.tif')
+    gt = ds.ReadAsArray()
+    ds = None
+
+    sel = np.where(gt > 0)
+    lst = []
+    ts_size = None
+
+    for fn in imfeat:
+        ds = gdal.Open(fn)
+        if ts_size is None:
+            ts_size = ds.RasterCount
+        img = ds.ReadAsArray()
+        ds = None
+        lst.append(np.moveaxis(img, 0, 2)[sel])
+
+    feat = np.concatenate(tuple(lst), axis=1)
+    feat[np.where(np.isnan(feat))] = 0
+
+    if do_norm:
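+        # per-band min-max scaling over all dates: column i::ts_size is band i of every
+        # date, assuming each *_FEAT.tif exposes the same ts_size bands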
+        for i in range(ts_size):
+            t_min, t_max = np.min(feat[:, i::ts_size]), np.max(feat[:, i::ts_size])
+            feat[:, i::ts_size] = (feat[:, i::ts_size] - t_min) / (t_max - t_min)
+    labl = gt[sel]
+
+    # np.save(ofld + os.sep + sname + '_feat.npy', feat)
+    # np.save(ofld + os.sep + sname + '_labl.npy', labl)
+
+    return feat, labl
+
+def generateMultiSourceAnnotatedDataset(source_list, patch_sizes, reference, cl_field, image_reference, sampling_strategy='smallest', sname='AnnotatedSet', ofld='.', do_norm=True):
+
+    lbl = True
+
+    # Subprocess version
+    stat_file = ofld + os.sep + os.path.splitext(os.path.basename(reference))[0] + '_polystat.xml'
+    cmd = ['otbcli_PolygonClassStatistics', '-in', image_reference, '-vec', reference, '-field', cl_field, '-out',
+           stat_file]
+    subprocess.call(cmd)
+
+    # Subprocess version
+    samples_file = ofld + os.sep + os.path.splitext(os.path.basename(reference))[0] + '_sample_loc.shp'
+    cmd = ['otbcli_SampleSelection', '-in', image_reference, '-vec', reference, '-field', cl_field, '-out',
+           samples_file, '-instats', stat_file, '-strategy', sampling_strategy]
+    subprocess.call(cmd)
+
+    for src,psz in zip(source_list,patch_sizes):
+        mns = []
+        sds = []
+
+        cmd = ['otbcli_PatchesExtraction']
+        n_timestamps = 1
+        if os.path.isdir(src):
+            imfeat = sorted(glob.glob(src + os.sep + '*_FEAT.tif'))
+            cmd += ['-source1.il'] + imfeat
+            n_timestamps = len(imfeat)
+            mns,sds = computeGlobalMeanStd(imfeat)
+        else:
+            cmd += ['-source1.il', src]
+            mns, sds = computeGlobalMeanStd(src)
+
+        cmd += ['-source1.patchsizex', str(psz), '-source1.patchsizey', str(psz), '-vec', samples_file, '-field', cl_field]
+        if lbl:
+            labels_file = ofld + os.sep + sname + '_sample_labels.tif'
+            cmd += ['-outlabels', labels_file, 'uint16']
+            lbl = False
+
+        patch_file = ofld + os.sep + sname + '_' + os.path.splitext(os.path.basename(src))[0] + '_samples.tif'
+        cmd += ['-source1.out', patch_file, 'float']
+        subprocess.call(cmd)
+
+        ds = gdal.Open(patch_file)
+        arr = ds.ReadAsArray()
+        ds = None
+
+        if n_timestamps == 1:
+
+            if do_norm:
+                # band-by-band
+                '''
+                for b in range(arr.shape[0]):
+                    b_min, b_max = np.min(arr[b]), np.max(arr[b])
+                    arr[b] = (arr[b] - b_min) / (b_max - b_min)
+                '''
+                # all bands
+                '''
+                b_min, b_max = np.min(arr), np.max(arr)
+                arr = (arr - b_min) / (b_max - b_min)
+                '''
+                # global standardization
+                for b in range(arr.shape[0]):
+                    arr[b] = (arr[b] - mns[b]) / sds[b]
+
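+            # (bands, n_patches*psz, psz) -> (n_patches, psz, psz, bands); note that
+            # arr.shape here still refers to the array before the moveaxis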
+            arr = np.reshape(np.moveaxis(arr, 0, -1), (int(arr.shape[1] / arr.shape[2]), arr.shape[2], arr.shape[2], arr.shape[0]))
+            np.save(patch_file.replace('.tif','.npy'),arr)
+
+        else:
+
+            ts_size = int(arr.shape[0] / n_timestamps)
+            arr = np.moveaxis(arr, 0, -1).squeeze()
+            arr = np.nan_to_num(arr)
+
+            if do_norm:
+                # min-max bandwise scaling
+                '''
+                for i in range(ts_size):
+                    t_min, t_max = np.min(arr[:, i::ts_size]), np.max(arr[:, i::ts_size])
+                    arr[:, i::ts_size] = (arr[:, i::ts_size] - t_min) / (t_max - t_min)
+                '''
+                # bandwise standardization
+                for i in range(ts_size):
+                    arr[:, i::ts_size] = (arr[:, i::ts_size] - mns[i]) / sds[i]
+            np.save(patch_file.replace('.tif', '.npy'), arr)
+
+        ds = gdal.Open(labels_file)
+        arr = ds.ReadAsArray().squeeze()
+        ds = None
+        np.save(labels_file.replace('.tif', '.npy'), arr)
+
+    return
+
+def computeGlobalMeanStd(sources):
+
+    timeseries = True
+    if not isinstance(sources,list):
+        sources = [sources]
+        timeseries = False
+
+    ds = gdal.Open(sources[0])
+    ndv = ds.GetRasterBand(1).GetNoDataValue()
+    ds = None
+    print('Using no-data value : ' + str(ndv))
+
+    nm = os.path.splitext(sources[0])[0] + '.stats.xml'
+    if not os.path.exists(nm):
+        cmd = ['otbcli_ComputeImagesStatistics','-il'] + sources + ['-out', nm]
+        if ndv is not None:
+            cmd += ['-bv', str(ndv)]
+        subprocess.call(cmd)
+
+    vals = []
+    root = ET.parse(nm).getroot()
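+    # the stats XML lists all per-band means first, then all per-band standard deviations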
+    for item in root.findall('Statistic'):
+        for child in item:
+            vals.append(float(child.attrib['value']))
+    mns = np.array(vals[:int(len(vals)/2)])
+    sds = np.array(vals[int(len(vals)/2):])
+
+    """
+    if timeseries:
+        tmns = []
+        tsds = []
+        for i in range(len(sources)):
+            tmns.append(np.mean(mns[i::B]))
+        tmns = np.array(tmns)
+        for i in range(len(sources)):
+            tsds.append(np.sqrt((np.sum(np.power(sds[i::B],2)) + np.sum(np.power(mns[i::B]-tmns[i],2)))/B))
+        mns = np.array(tmns)
+        sds = np.array(sds)
+    """
+
+    return mns,sds
+
+def CNN_preprocessing(image, gt, size, sname,ofld):
+
+    patch_size = size
+    valid = int(patch_size/2)
+
+    # move to a bands-last layout for patch extraction
+    image = np.moveaxis(image, 0, 2)
+
+    # drop reference pixels too close to the border to extract a full patch
+    gt[0:valid, :] = 0
+    gt[gt.shape[0]-valid:, :] = 0
+    gt[:, 0:valid] = 0
+    gt[:, gt.shape[1]-valid:] = 0
+    sel = np.where(gt > 0)
+
+    x, y = sel
+
+    lista = []
+    labels = []
+
+    for n in range(len(x)):
+        begin_i = x[n] - valid
+        end_i = x[n] + valid + 1
+        begin_j = y[n] - valid
+        end_j = y[n] + valid + 1
+
+        patch = image[begin_i:end_i, begin_j:end_j, :]
+        lista.append(patch)
+        labels.append(gt[x[n], y[n]])
+
+    out = np.array(lista)
+    labels = np.array(labels)
+    # np.save(ofld + os.sep + sname + '_feat.npy', out)
+    # np.save(ofld + os.sep + sname + '_labl.npy', labels)
+
+    return out,labels
+
+def CNNAnnotedDataset(VHR, sampling, cl_field, size, sname='AnnotatedSet', ofld='.'):
+    ds = gdal.Open(VHR)
+    vhr_array = ds.ReadAsArray()
+    ds = None
+
+    cmd = ['otbcli_Rasterization', '-in', sampling, '-im', VHR, '-mode', 'attribute', '-mode.attribute.field', cl_field, '-out', ofld + os.sep + sname + '_gt.tif', 'uint16']
+    subprocess.call(cmd)
+
+    ds = gdal.Open(ofld + os.sep + sname + '_gt.tif')
+    labels = ds.ReadAsArray()
+    ds = None
+
+    feat, labl = CNN_preprocessing(vhr_array, labels, size, sname, ofld)
+    return feat, labl
\ No newline at end of file
diff --git a/TFmodels.py b/TFmodels.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b7f3e8c502515be94caa7a176a35d4f1e3089f7
--- /dev/null
+++ b/TFmodels.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+#-*- coding: utf-8 -*-
+
+# =========================================================================
+#   Program:   moringa
+#
+#   Copyright (c) TETIS. All rights reserved.
+#
+#   See LICENSE for details.
+#
+#   This software is distributed WITHOUT ANY WARRANTY; without even
+#   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+#   PURPOSE.  See the above copyright notices for more information.
+#
+# =========================================================================
+
+import tensorflow as tf
+
+# n_timestamps = 21
+# ts_size = 16
+# patch_size = 25
+# n_bands = 4
+
+def retrieveTSinfo(feats):
+    # Field names follow the T<series>F<feature>D<date><stat> pattern produced by
+    # mrzonalstats; return the number of series, of features and of dates.
+    nb_comp = 0
+    nb_date = 0
+    nb_T = max([int(x.split('F')[0].split('T')[-1]) for x in feats])
+    for i in range(1, nb_T + 1):
+        # only look at the fields belonging to time series i
+        feats_i = [x for x in feats if x.split('F')[0] == 'T' + str(i)]
+        nb_comp += max([int(x.split('D')[0].split('F')[-1]) for x in feats_i])
+        nb_date += max([int(x.split('m')[0].split('D')[-1]) for x in feats_i])
+    return nb_T, nb_comp, nb_date
+
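+# Example, assuming zonal-stats field names of the form T<series>F<feature>D<date><stat>
+# (e.g. 'T1F2D5mean'): retrieveTSinfo(['T1F1D1mean','T1F1D2mean','T1F2D1mean','T1F2D2mean'])
+# returns (1, 2, 2), i.e. one series with 2 features over 2 dates.
+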
+class M3FusionModel(tf.keras.Model):
+    
+    def __init__(self,n_classes,n_timestamps,ts_size,patch_size,n_bands):
+        super(M3FusionModel, self).__init__()
+        # RNN branch for the time series
+        self.input_ts = tf.keras.layers.Input(shape=(n_timestamps,ts_size),name='timeseries_input')
+        #rnn_out = keras.layers.GRU(512,return_sequences=True,name='gru_base')(input_ts)
+        self.rnn_out1 = tf.keras.layers.GRU(512,name='gru_base')
+        self.rnn_out2 = tf.keras.layers.Dropout(rate=0.5,name='gru_dropout')
+        #rnn_out = BasicAttention(name='gru_attention')(rnn_out)
+        self.rnn_aux = tf.keras.layers.Dense(n_classes,activation='softmax',name='rnn_dense_layer_'+str(n_classes))
+
+        #CNN
+        self.input_vhr = tf.keras.layers.Input(shape=(patch_size,patch_size,n_bands), name='vhr_input')
+        self.cnn1_1 = tf.keras.layers.Conv2D(256,[7,7],activation='relu',name='cnn_conv1')
+        self.cnn1_2 = tf.keras.layers.BatchNormalization(name='cnn_conv1_bn')
+        self.cnn1_3 = tf.keras.layers.MaxPooling2D(strides=(2,2),name='cnn_conv1_pool')
+        self.cnn2_1 = tf.keras.layers.Conv2D(512,[3,3],activation='relu',name='cnn_conv2')
+        self.cnn2_2 = tf.keras.layers.BatchNormalization(name='cnn_conv2_bn')
+        self.cnn3_1 = tf.keras.layers.Conv2D(512,[3,3],activation='relu',padding='same',name='cnn_conv3')
+        self.cnn3_2 = tf.keras.layers.BatchNormalization(name='cnn_conv3_bn')
+        self.cnn4_1 = tf.keras.layers.Concatenate(axis=3,name='cnn_inner_concat')
+        self.cnn4_2 = tf.keras.layers.Conv2D(512,[1,1],activation='relu',name='cnn_conv4')
+        self.cnn4_3 = tf.keras.layers.BatchNormalization(name='cnn_conv4_bn')
+        self.cnn_out = tf.keras.layers.GlobalAveragePooling2D(name='cnn_conv4_gpool')
+        self.cnn_aux = tf.keras.layers.Dense(n_classes,activation='softmax',name='cnn_dense_layer_'+str(n_classes))
+
+        #Merge
+        self.classifier = tf.keras.layers.Concatenate(axis=-1,name='rnn_cnn_merge')
+        self.classifier_out = tf.keras.layers.Dense(n_classes,activation='softmax',name='full_dense_layer_'+str(n_classes))
+
+    def call(self, inputs):
+        ts = inputs[0]
+        # patches may be stored as integers, the convolutions expect floats
+        image = tf.cast(inputs[1], tf.float32)
+        rnn_out = self.rnn_out1(ts)
+        rnn_out = self.rnn_out2(rnn_out)
+        rnn_aux = self.rnn_aux(rnn_out)
+
+        cnn1 = self.cnn1_1(image)
+        cnn1 = self.cnn1_2(cnn1)
+        cnn1 = self.cnn1_3(cnn1)
+        cnn2 = self.cnn2_1(cnn1)
+        cnn2 = self.cnn2_2(cnn2)
+        cnn3 = self.cnn3_1(cnn2)
+        cnn3 = self.cnn3_2(cnn3)
+        cnn4 = self.cnn4_1([cnn2,cnn3])
+        cnn4 = self.cnn4_2(cnn4)
+        cnn4 = self.cnn4_3(cnn4)
+        cnn_out = self.cnn_out(cnn4)
+        cnn_aux = self.cnn_aux(cnn_out)
+
+        classifier = self.classifier([rnn_out,cnn_out])
+        classifier_out = self.classifier_out(classifier)
+        return classifier_out, rnn_aux, cnn_aux 
+
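+# Usage sketch with the commented defaults above (21 dates of 16 features, 25x25
+# patches with 4 bands); the model returns three softmax outputs: the fused
+# classifier plus the RNN and CNN auxiliary heads.
+# model = M3FusionModel(n_classes, n_timestamps=21, ts_size=16, patch_size=25, n_bands=4)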
diff --git a/classificationWorkflow.py b/classificationWorkflow.py
index a23789bd20864a928b02106ccd7dc9bbcf020be1..0ff55ff526250f82f43463e88c4e250eb61565f4 100644
--- a/classificationWorkflow.py
+++ b/classificationWorkflow.py
@@ -250,7 +250,7 @@ def deepTraining(data,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
 
     if params[1] == 'standard':
         #Read training shp with geopandas
-        ds = geopandas.read_file(data)
+        ds = geopandas.read_file(data[0])
         #Extract feats and 
         feats = ds.filter(feat,axis=1)
         targets = ds[code]
@@ -284,8 +284,7 @@ def deepTraining(data,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
         # weights_path = model_file.replace('.','_weights.')
         log_path = model_file.replace('.h5','.log')
         log = CSVLogger(log_path, separator=',', append=False)
-        ts_size, n_timestamps = TFmodels.retrieveTSinfo(feat)
-        print(ts_size, n_timestamps)
+        nb_T, ts_size, n_timestamps = TFmodels.retrieveTSinfo(feat)
         patch_size = 25
         n_bands = 4
 
@@ -318,6 +317,10 @@ def deepTraining(data,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
         elif len(data) == 6:
             x_valid_rnn = np.load(data[3])
             x_valid_rnn = scaler_rnn.transform(x_valid_rnn)
+            # WARNING: the handling of the TxDxmx fields probably needs to be rethought. For now each time series type is kept as a separate block of features, but the RNN has no way of knowing that.
+            # Option 1: build a single composite time series mixing the Sentinel and Venus/Landsat data, resampled to a common set of dates (e.g. artificially gap-filled dates), with a number of features equal to the sum of the features of each series.
+            # Option 2: keep the series separate, accepting that the dates end up out of order (all dates of the first series, then all dates of the second).
+            # Option 3: build one RNN branch per time series type and merge them before the RNN/CNN fusion.
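+            # A minimal, commented-out sketch of option 3 (hypothetical: 'per_series_inputs'
+            # would come from splitting the rnn features by their T prefix before reshaping):
+            #   branches = [tf.keras.layers.GRU(512)(inp) for inp in per_series_inputs]
+            #   merged = tf.keras.layers.Concatenate()(branches)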
             x_valid_rnn = np.reshape(x_valid_rnn, (x_valid_rnn.shape[0], n_timestamps, -1))
             x_valid_cnn = np.load(data[4])
             x_valid_cnn = scaler_cnn.transform(x_valid_cnn.reshape(-1,patch_size*patch_size*n_bands)).reshape(-1,patch_size,patch_size,n_bands)
@@ -346,16 +349,18 @@ def deepClassify(shp_list,code,scalers,model_file, csv_classes,out_fld,out_ext,f
     import TFmodels
     import joblib
     import csv
-    
+    #load scaler and model
     if len(scalers)== 1:
         method = 'standard'
         scaler = joblib.load(scalers[0])
         model = tf.keras.models.load_model(model_file)
+    # M3Fusion uses two scalers: one for the time series, one for the VHR patch stack
     elif len(scalers)== 2:
         method = 'M3Fusion'
-        scaler_rnn = joblib.load(scalers[1])
-        scaler_cnn = joblib.load(scalers[0])
+        scaler_rnn = joblib.load(scalers[0])
+        scaler_cnn = joblib.load(scalers[1])
 
+    # Load the class index / class code correspondence table
     dict_classes = {}
     with open(csv_classes,'r') as file:
         rdr = csv.reader(file)
@@ -366,9 +371,11 @@ def deepClassify(shp_list,code,scalers,model_file, csv_classes,out_fld,out_ext,f
     if method == 'standard' :
         for shp in shp_list:
             out_file = os.path.join(out_fld, os.path.basename(shp).replace('.shp', out_ext + '.shp'))
+            #Load segmentation shp and retrieve time series feats
             ds = geopandas.read_file(shp)
             feats = ds.filter(feat,axis=1)
             cfeats = scaler.transform(feats)
+            #Classify
             predict = model.predict(cfeats)
             class_predict = [dict_classes[np.where(p==np.max(p))[0][0]] for p in predict]
             out = ds.filter(items=['Segment_ID',code[1:],'geometry'])
@@ -376,13 +383,15 @@ def deepClassify(shp_list,code,scalers,model_file, csv_classes,out_fld,out_ext,f
             if compute_confidence:
                 confmap = [np.max(p) for p in predict]
                 out.insert(out.shape[1]-1,'confidence',confmap)
+            #Save output
             out.to_file(out_file)
             out_file_list.append(out_file)
     elif method == 'M3Fusion':
         import Moringa2DL
-        ts_size, n_timestamps = TFmodels.retrieveTSinfo(feat)
+        nb_T, ts_size, n_timestamps = TFmodels.retrieveTSinfo(feat)
         patch_size = 25
         n_bands = 4
+        #Reload custom model
         model = TFmodels.M3FusionModel(len(dict_classes),n_timestamps,ts_size,patch_size,n_bands)
         opt = tf.keras.optimizers.Adam(learning_rate=0.0002)
         x1= np.ndarray([1,n_timestamps,ts_size])
@@ -396,20 +405,20 @@ def deepClassify(shp_list,code,scalers,model_file, csv_classes,out_fld,out_ext,f
               )
         for shp in shp_list:
             out_file = os.path.join(out_fld, os.path.basename(shp).replace('.shp', out_ext + '.shp'))
+            #Generate rnn and cnn feats
             rnn, cnn, id_geom = Moringa2DL.generateAllForClassify(feat_folder,shp,seg_src,feat,cfield='Segment_ID',ofld=os.path.dirname(shp))
-            ds = geopandas.read_file(shp)
             rnn_feats = np.load(rnn)
-            if len(ds) != len(rnn_feats):
-                sys.exit("different number of elements between {} and {}".format(shp,rnn))
             rnn_feats = scaler_rnn.transform(rnn_feats)
             rnn_feats = np.reshape(rnn_feats, (rnn_feats.shape[0], n_timestamps, -1))
             cnn_feats = np.load(cnn)
             cnn_feats = scaler_cnn.transform(cnn_feats.reshape(-1,patch_size*patch_size*n_bands)).reshape(-1,patch_size,patch_size,n_bands)
+            #Classify
             predict = model.predict([rnn_feats,cnn_feats])
             class_predict = np.array([dict_classes[np.where(p==np.max(p))[0][0]] for p in predict[0]])
             id_join = np.load(id_geom)
             stack = np.column_stack([id_join,class_predict])
             predict_dict={}
+            # Join predictions on Segment_ID
             for i in np.unique(id_join):
                 predict_dict[int(i)]=[]
             for i,j in stack:
@@ -417,6 +426,9 @@ def deepClassify(shp_list,code,scalers,model_file, csv_classes,out_fld,out_ext,f
             classif = {}
             for i,j in predict_dict.items():
                 classif[i]=np.bincount(j).argmax()
+            ds = geopandas.read_file(shp)
+            if len(ds) != len(classif):
+                sys.exit("different number of elements between {} and the predictions derived from {}".format(shp,rnn))
             if validation == False :
                 out = ds.filter(items=['Segment_ID','geometry'])
             else :
@@ -436,6 +448,7 @@ def deepClassify(shp_list,code,scalers,model_file, csv_classes,out_fld,out_ext,f
                     confidence[i]=np.mean(j)
                 df = geopandas.GeoDataFrame(np.array(list(confidence.items())),columns=['id','confidence'])
                 out = out.join(df.set_index('id'))
+            #Save output
             out.to_file(out_file)
             out_file_list.append(out_file)
     return out_file_list
diff --git a/launchChain.py b/launchChain.py
index cecd95382286ad6161d75ff223cc4ba052351d8a..e32e3869177a6de6a603b958cc2d7642826f0d17 100644
--- a/launchChain.py
+++ b/launchChain.py
@@ -557,105 +557,105 @@ def main(argv):
         # generating GT samples and merging tiles
 
         params = config.get('TRAINING CONFIGURATION', 'parameters').split(' ')
-        if '-classifier' in params or ('-tensorflow' in params and 'standard' in params) :
-            if os.path.exists(samples_fld + '/GT_samples.shp'):
-                warnings.warn('File ' + samples_fld + '/GT_samples.shp' + ' already exists, skipping zonal stats on reference.')
-            else:
-                gttl = sorted(glob.glob(seg_fld + '/GT_segmentation_*.shp'))
-                ogttl = []
+        # if '-classifier' in params or ('-tensorflow' in params and 'standard' in params) :
+        if os.path.exists(samples_fld + '/GT_samples.shp'):
+            warnings.warn('File ' + samples_fld + '/GT_samples.shp' + ' already exists, skipping zonal stats on reference.')
+        else:
+            gttl = sorted(glob.glob(seg_fld + '/GT_segmentation_*.shp'))
+            ogttl = []
 
-                ext_flds = getFieldNames(gttl[0])
+            ext_flds = getFieldNames(gttl[0])
 
-                # BEGIN ** GT_SAMPLES **
-                cmd_list = []
+            # BEGIN ** GT_SAMPLES **
+            cmd_list = []
 
-                for tl in gttl:
-                    tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-                    for fn in tocp:
-                        shutil.copyfile(fn, samples_fld + '/' + os.path.basename(fn))
-                    otl = samples_fld + '/' + os.path.basename(tl)
-                    ogttl.append(otl)
-                    cmd_list.append(
-                        ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                         '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+            for tl in gttl:
+                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                for fn in tocp:
+                    shutil.copyfile(fn, samples_fld + '/' + os.path.basename(fn))
+                otl = samples_fld + '/' + os.path.basename(tl)
+                ogttl.append(otl)
+                cmd_list.append(
+                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
 
-                queuedProcess(cmd_list, N_processes=1, shell=sh)
-                # END ** GT_SAMPLES **
+            queuedProcess(cmd_list, N_processes=1, shell=sh)
+            # END ** GT_SAMPLES **
 
-                var_keys = [x for x in getFieldNames(ogttl[0]) if x not in ext_flds]
+            var_keys = [x for x in getFieldNames(ogttl[0]) if x not in ext_flds]
 
-                with open(samples_fld + '/field_names.csv', 'w') as varfile:
-                    vw = csv.writer(varfile)
-                    for x in zip(var_keys, feat_to_vars):
-                        vw.writerow(x)
+            with open(samples_fld + '/field_names.csv', 'w') as varfile:
+                vw = csv.writer(varfile)
+                for x in zip(var_keys, feat_to_vars):
+                    vw.writerow(x)
 
-                cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
+            cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
 
-                cmd = cmd_base + [samples_fld + '/GT_samples.shp', ogttl[0]]
+            cmd = cmd_base + [samples_fld + '/GT_samples.shp', ogttl[0]]
+            subprocess.call(cmd, shell=sh)
+
+            for k in range(1, len(ogttl)):
+                cmd = cmd_base + ['-update', '-append', samples_fld + '/GT_samples.shp', ogttl[k], '-nln', 'GT_samples']
                 subprocess.call(cmd, shell=sh)
 
-                for k in range(1, len(ogttl)):
-                    cmd = cmd_base + ['-update', '-append', samples_fld + '/GT_samples.shp', ogttl[k], '-nln', 'GT_samples']
-                    subprocess.call(cmd, shell=sh)
+            for otl in ogttl:
+                shpd.DeleteDataSource(otl)
+                os.remove(otl.replace('.shp','.tif'))
 
-                for otl in ogttl:
-                    shpd.DeleteDataSource(otl)
-                    os.remove(otl.replace('.shp','.tif'))
+        if ch_mode <= 0 or ch_mode == 2:
+            # generating VAL samples and merging tiles
+            gttl = sorted(glob.glob(seg_fld + '/VAL_segmentation_*.shp'))
+            ogttl = []
 
-            if ch_mode <= 0 or ch_mode == 2:
-                # generating VAL samples and merging tiles
-                gttl = sorted(glob.glob(seg_fld + '/VAL_segmentation_*.shp'))
-                ogttl = []
+            # BEGIN ** VAL_SAMPLES **
+            cmd_list = []
 
-                # BEGIN ** VAL_SAMPLES **
-                cmd_list = []
+            for tl in gttl:
+                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                for fn in tocp:
+                    shutil.copyfile(fn, val_fld + '/' + os.path.basename(fn))
+                otl = val_fld + '/' + os.path.basename(tl)
+                ogttl.append(otl)
+                cmd_list.append(
+                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
 
-                for tl in gttl:
-                    tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-                    for fn in tocp:
-                        shutil.copyfile(fn, val_fld + '/' + os.path.basename(fn))
-                    otl = val_fld + '/' + os.path.basename(tl)
-                    ogttl.append(otl)
-                    cmd_list.append(
-                        ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                         '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+            queuedProcess(cmd_list, N_processes=1, shell=sh)
+            # END ** VAL_SAMPLES **
 
-                queuedProcess(cmd_list, N_processes=1, shell=sh)
-                # END ** VAL_SAMPLES **
+            cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
 
-                cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
+            cmd = cmd_base + [val_fld + '/VAL_samples.shp', ogttl[0]]
+            subprocess.call(cmd, shell=sh)
 
-                cmd = cmd_base + [val_fld + '/VAL_samples.shp', ogttl[0]]
+            for k in range(1, len(ogttl)):
+                cmd = cmd_base + ['-update', '-append', val_fld + '/VAL_samples.shp', ogttl[k], '-nln', 'VAL_samples']
                 subprocess.call(cmd, shell=sh)
 
-                for k in range(1, len(ogttl)):
-                    cmd = cmd_base + ['-update', '-append', val_fld + '/VAL_samples.shp', ogttl[k], '-nln', 'VAL_samples']
-                    subprocess.call(cmd, shell=sh)
-
-                for otl in ogttl:
-                    shpd.DeleteDataSource(otl)
-                    os.remove(otl.replace('.shp', '.tif'))
+            for otl in ogttl:
+                shpd.DeleteDataSource(otl)
+                os.remove(otl.replace('.shp', '.tif'))
 
-            if ch_mode > 0 or force_sample_generation is True:
-                # generating map samples (TBD)
-                gttl = sorted(glob.glob(seg_fld + '/segmentation_*.shp'))
-                ogttl = []
+        if ch_mode > 0 or force_sample_generation is True:
+            # generating map samples (TBD)
+            gttl = sorted(glob.glob(seg_fld + '/segmentation_*.shp'))
+            ogttl = []
 
-                # BEGIN ** SAMPLES **
-                cmd_list = []
+            # BEGIN ** SAMPLES **
+            cmd_list = []
 
-                for tl in gttl:
-                    tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-                    for fn in tocp:
-                        shutil.copyfile(fn, test_fld + '/' + os.path.basename(fn))
-                    otl = test_fld + '/' + os.path.basename(tl)
-                    ogttl.append(otl)
-                    cmd_list.append(
-                        ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                         '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+            for tl in gttl:
+                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                for fn in tocp:
+                    shutil.copyfile(fn, test_fld + '/' + os.path.basename(fn))
+                otl = test_fld + '/' + os.path.basename(tl)
+                ogttl.append(otl)
+                cmd_list.append(
+                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
 
-                queuedProcess(cmd_list, N_processes=N_proc, shell=sh)
-            # END ** SAMPLES **
+            queuedProcess(cmd_list, N_processes=N_proc, shell=sh)
+        # END ** SAMPLES **
 
         if single_step:
             sys.exit("Single step mode. Exiting.")
@@ -718,7 +718,7 @@ def main(argv):
                             keepFields(val_out_tmp_list[0], val_out, ['Segment_ID', cfield, 'p' + cfield, 'confidence'])
                             shpd.DeleteDataSource(val_out_tmp_list[0])
                         elif '-tensorflow' in params:
-                            mfn, scaler, csv_classes = deepTraining(kfold_train_samples[i], cfield, mfld, params, var_list)
+                            mfn, scaler, csv_classes = deepTraining([kfold_train_samples[i]], cfield, mfld, params, var_list)
                             val_list = [kfold_test_samples[i]]
                             val_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '.shp')
                             val_out_check = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '_check.shp')
@@ -745,49 +745,18 @@ def main(argv):
                     training(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list)
                 elif params[0] == '-tensorflow':
                     if params[1] == 'standard':
-                        deepTraining(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list)
+                        if len(params) > 2:  # an optional number of epochs follows 'standard'
+                            epochs=int(params[2])
+                            deepTraining([samples_fld + '/GT_samples.shp'],cfield,model_fld,params,var_list,epochs=epochs)
+                        else :
+                            deepTraining([samples_fld + '/GT_samples.shp'],cfield,model_fld,params,var_list)
                     elif params[1].lower() == 'm3fusion':
                         import Moringa2DL
                         epochs=int(params[2])
-                        tf_data = Moringa2DL.generateAll(feat_fld+'/*',reference,seg_src,cfield,ofld=samples_fld)
-                        # tf_data=['/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/GT_SAMPLES/train_code_1_rnn.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/GT_SAMPLES/train_code_1_cnn.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/GT_SAMPLES/train_code_1_label.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/VAL_SAMPLES/valid_code_1_rnn.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/VAL_SAMPLES/valid_code_1_cnn.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/VAL_SAMPLES/valid_code_1_label.npy']
-                        # gttl = sorted(glob.glob(seg_fld + '/GT_segmentation_*.shp'))
-                        # gt_rnn_feat = None
-                        # gt_rnn_label = np.zeros(0)
-                        # gt_cnn_feat = None
-                        # gt_cnn_label = np.zeros(0)
-                        # for gt in gttl :
-                        #     rnn_feat, rnn_label = Moringa2DL.generateAnnotatedDataset(feat_fld, gt, cfield,sname=os.path.splitext(os.path.basename(gt))[0]+'_rnn_'+cfield,ofld=samples_fld)
-                        #     if gt_rnn_feat is None : 
-                        #         gt_rnn_feat = np.zeros((0,rnn_feat.shape[1]))
-                        #     np.concatenate(gt_rnn_feat,rnn_feat)
-                        #     np.concatenate(gt_rnn_label,rnn_label)
-                        #     cnn_feat, cnn_label = Moringa2DL.CNNAnnotedDataset(seg_src, gt, cfield, 25, sname=os.path.splitext(os.path.basename(gt))[0]+'_cnn_'+cfield, ofld=samples_fld)
-                        #     if gt_cnn_feat is None : 
-                        #         gt_cnn_feat = np.zeros((0,cnn_feat.shape[1],cnn_feat.shape[2],cnn_feat.shape[3]))
-                        #     np.concatenate(gt_cnn_feat,cnn_feat)
-                        #     np.concatenate(gt_cnn_label,cnn_label)
-
-                        # tf_data.append(gt_rnn_feat)
-                        # tf_data.append(gt_rnn_label)
-                        # tf_data.append(gt_cnn_feat)
-                        # tf_data.append(gt_rnn_feat)
-
-                        val_rnn_feat = None
-                        val_rnn_label = None
+                        tf_data = Moringa2DL.generateAll(feat_fld,reference,seg_src,cfield,ofld=samples_fld)
                         if ch_mode == 0 or ch_mode == 2 :
-                            tf_valid_data = Moringa2DL.generateAll(feat_fld+'/*',validation,seg_src,cfield,ofld=val_fld,dstype='valid')
+                            tf_valid_data = Moringa2DL.generateAll(feat_fld,validation,seg_src,cfield,ofld=val_fld,dstype='valid')
                             tf_data += tf_valid_data
-                            # gttl = sorted(glob.glob(seg_fld + '/VAL_segmentation_*.shp'))
-                            # val_rnn_label = np.zeros(0)
-                            # for gt in gttl :
-                            #     rnn_feat, rnn_label = Moringa2DL.generateAnnotatedDataset(feat_fld, gt, cfield,sname=os.path.splitext(os.path.basename(gt))[0]+'_'+cfield,ofld=val_fld)
-                            #     if val_rnn_feat is None : 
-                            #         val_rnn_feat = np.zeros((0,rnn_feat.shape[1]))
-                            #     np.concatenate(val_rnn_feat,rnn_feat)
-                            #     np.concatenate(val_rnn_label,rnn_label)
-                            # tf_data.append(val_rnn_feat)
-                            # tf_data.append(val_rnn_label)
                         deepTraining(tf_data,cfield,model_fld,params,var_list,epochs=epochs)
                     else:
                         sys.exit("Tensorflow parameter {} is not available".format(params[1]))
@@ -877,12 +846,11 @@ def main(argv):
     stat_file = model_fld + '/GT_stats.xml' if '-tensorflow' not in params else 'scaler'
 
     shpd = ogr.GetDriverByName('ESRI Shapefile')
-
     if input_runlevel < 7:
         if hierarchical_classif is False :
             for cfield in cfieldlist:
                 if '-tensorflow' not in params:
-                    model_file = os.apth.join(model_fld, classifier+ '_' + cfield + '.model')
+                    model_file = os.path.join(model_fld, classifier+ '_' + cfield + '.model')
                     if not os.path.exists(model_file):
                         warnings.warn('Error: Model file ' + model_file + ' not found. Skipping.')
                         continue
@@ -924,9 +892,24 @@ def main(argv):
                             keepFields(cshp,map_out,['Segment_ID','p'+cfield,'confidence'])
                             shpd.DeleteDataSource(cshp)
                             map_list.append(map_out)
+
+                        if rasout == 'VRT':
+                            if not os.path.exists(final_fld + '/MAPS'):
+                                os.mkdir(final_fld + '/MAPS')
+                            if not os.path.exists(final_fld + '/MAPS/RASTER_' + cfield):
+                                os.mkdir(final_fld + '/MAPS/RASTER_' + cfield)
+                            ras_list = []
+                            cmd_list = []
+                            for map,ref in zip(map_list,ref_list):
+                                ras_list.append(final_fld + '/MAPS/RASTER_' + cfield + '/' + os.path.basename(map).replace('.shp', '.tif'))
+                                cmd = ['otbcli_Rasterization', '-in', map, '-im', ref, '-mode', 'attribute', '-mode.attribute.field', 'p'+cfield, '-out', ras_list[-1]]
+                                cmd_list.append(cmd)
+                            queuedProcess(cmd_list,N_processes=N_proc,shell=sh)
+
+                            cmd = ['gdalbuildvrt', '-srcnodata', '0', '-vrtnodata', '0', final_fld + '/MAPS/RASTER_' + cfield + '/Classif_' + cfield + '.vrt'] + ras_list
+                            subprocess.call(cmd, shell=sh)
                 else :
                     model_file = os.path.join(model_fld, classifier + '_' + cfield + '.h5')
-                    print (model_file)
                     scaler_files = glob.glob(os.path.join(model_fld, classifier + '_scaler_' + cfield + '*.joblib'))
                     csv_classes = os.path.join(model_fld, classifier + '_class_'+ cfield +'.csv')
                     
@@ -960,19 +943,21 @@ def main(argv):
 
                         map_list = deepClassify(shp_list,'p'+cfield,scaler_files,model_file,csv_classes,final_fld + '/MAPS/VECTOR_' + cfield + '/','_' + cfield,var_list,Nproc=N_proc,compute_confidence=comp_conf,feat_folder=feat_fld+'/*', seg_src=seg_src)
 
-                if rasout == 'VRT':
-                    if not os.path.exists(final_fld + '/MAPS/RASTER_' + cfield):
-                        os.mkdir(final_fld + '/MAPS/RASTER_' + cfield)
-                    ras_list = []
-                    cmd_list = []
-                    for map,ref in zip(map_list,ref_list):
-                        ras_list.append(final_fld + '/MAPS/RASTER_' + cfield + '/' + os.path.basename(map).replace('.shp', '.tif'))
-                        cmd = ['otbcli_Rasterization', '-in', map, '-im', ref, '-mode', 'attribute', '-mode.attribute.field', 'p'+cfield, '-out', ras_list[-1]]
-                        cmd_list.append(cmd)
-                    queuedProcess(cmd_list,N_processes=N_proc,shell=sh)
-
-                    cmd = ['gdalbuildvrt', '-srcnodata', '0', '-vrtnodata', '0', final_fld + '/MAPS/RASTER_' + cfield + '/Classif_' + cfield + '.vrt'] + ras_list
-                    subprocess.call(cmd, shell=sh)
+                        if rasout == 'VRT':
+                            if not os.path.exists(final_fld + '/MAPS'):
+                                os.mkdir(final_fld + '/MAPS')
+                            if not os.path.exists(final_fld + '/MAPS/RASTER_' + cfield):
+                                os.mkdir(final_fld + '/MAPS/RASTER_' + cfield)
+                            ras_list = []
+                            cmd_list = []
+                            for map,ref in zip(map_list,ref_list):
+                                ras_list.append(final_fld + '/MAPS/RASTER_' + cfield + '/' + os.path.basename(map).replace('.shp', '.tif'))
+                                cmd = ['otbcli_Rasterization', '-in', map, '-im', ref, '-mode', 'attribute', '-mode.attribute.field', 'p'+cfield, '-out', ras_list[-1]]
+                                cmd_list.append(cmd)
+                            queuedProcess(cmd_list,N_processes=N_proc,shell=sh)
+
+                            cmd = ['gdalbuildvrt', '-srcnodata', '0', '-vrtnodata', '0', final_fld + '/MAPS/RASTER_' + cfield + '/Classif_' + cfield + '.vrt'] + ras_list
+                            subprocess.call(cmd, shell=sh)
 
         else :
             from classificationWorkflow import Hclassify
@@ -1021,8 +1006,11 @@ def main(argv):
                     shpd.DeleteDataSource(cshp)
                     map_list.append(map_out)
                 map_list.sort()
+
                 if rasout == 'VRT':
                     for cfield in cfieldlist :
+                        if not os.path.exists(final_fld + '/MAPS'):
+                            os.mkdir(final_fld + '/MAPS')
                         if not os.path.exists(final_fld + '/MAPS/RASTER_' + cfield):
                             os.mkdir(final_fld + '/MAPS/RASTER_' + cfield)
                         ras_list = []