diff --git a/classificationWorkflow.py b/classificationWorkflow.py
index bd74d10731a580facabaa4111736c2c738851147..8e41e1a454660028d1531f84104173a2235e6cc4 100644
--- a/classificationWorkflow.py
+++ b/classificationWorkflow.py
@@ -164,18 +164,38 @@ def baseDeepTraining(shp,code,flds, params, model_file, out_scaler, csv_classes,
         model.add(tf.keras.layers.Dense(nb_class, activation="softmax"))
 
         model.compile(loss="sparse_categorical_crossentropy",
-            optimizer="sgd",
+            optimizer="Adam",
             metrics=["accuracy"])
 
-    history = model.fit(cfeats, ctargets, epochs=epochs)
-    model.save(model_file)
-    joblib.dump(scaler,out_scaler)
+        history = model.fit(cfeats, ctargets, epochs=epochs)
+        model.save(model_file)
+        joblib.dump(scaler,out_scaler)
+
     with open(csv_classes,'w') as file:
         writer = csv.writer(file)
         for i,c in enumerate(classes):
             writer.writerow([i,c])
 
-
 def baseDeepClassify(shp, model_file, code,flds, out_file, compute_confidence=False):
     import geopandas
     import tensorflow as tf
@@ -209,7 +229,7 @@ def baseDeepClassify(shp, model_file, code,flds, out_file, compute_confidence=Fa
         out_file = shp
     ds.to_file(out_file)
 
-def deepTraining(shp,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
+def deepTraining(data,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
     import geopandas
     import tensorflow as tf
     from sklearn.preprocessing import StandardScaler
@@ -218,26 +238,29 @@ def deepTraining(shp,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
 
     if params[1] == 'standard':
         ext = "tfDense"
+    elif params[1].lower() == 'm3fusion':
+        ext = "M3F"
+        if len(data) not in [3,6]:
+            sys.exit("Wrong number of arguments for train/valid dataset, must be 3 or 6")
 
     model_file = os.path.join(model_fld, ext + '_' + code + '.h5')
     scaler_file = os.path.join(model_fld, ext + '_scaler_' + code + '.joblib')
+    scaler_files = []
     csv_classes = os.path.join(model_fld, ext + '_class_'+ code + '.csv')
 
-    #Read training shp with geopandas
-    ds = geopandas.read_file(shp)
-    #Extract feats and 
-    feats = ds.filter(feat,axis=1)
-    targets = ds[code]
-    nb_class = len(targets.unique())
-    classes = np.array(sorted(targets.unique()))
-    nb_class = len(targets.unique())
-    ctargets = np.array([np.where(classes==i)[0][0] for i in targets])
-
-    scaler = StandardScaler()
-    scaler.fit(feats)
-    cfeats = scaler.transform(feats)
-
     if params[1] == 'standard':
+        #Read training shp with geopandas
+        ds = geopandas.read_file(data)
+        #Extract feats and targets
+        feats = ds.filter(feat,axis=1)
+        targets = ds[code]
+        nb_class = len(targets.unique())
+        classes = np.array(sorted(targets.unique()))
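+        # Encode each target value as its index in the sorted class array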
+        ctargets = np.array([np.where(classes==i)[0][0] for i in targets])
+
+        scaler = StandardScaler()
+        scaler.fit(feats)
+        cfeats = scaler.transform(feats)
         # Model init
         model = tf.keras.models.Sequential()
 
@@ -251,15 +274,70 @@ def deepTraining(shp,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
             optimizer="sgd",
             metrics=["accuracy"])
 
-    history = model.fit(cfeats, ctargets, epochs=epochs)
-    model.save(model_file)
-    joblib.dump(scaler,scaler_file)
+        history = model.fit(cfeats, ctargets, epochs=epochs)
+        model.save(model_file)
+        joblib.dump(scaler,scaler_file)
+        scaler_files.append(scaler_file)
+    elif params[1].lower() == 'm3fusion':
+        from tensorflow.keras.callbacks import CSVLogger
+        import TFmodels
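+        # Write per-epoch training metrics to a CSV log next to the model file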
+        log_path = model_file.replace('.h5','.log')
+        log = CSVLogger(log_path, separator=',', append=False)
+        ts_size, n_timestamps = TFmodels.retrieveTSinfo(feat)
+        print("ts_size: {}, n_timestamps: {}".format(ts_size, n_timestamps))
+        patch_size = 25
+        n_bands = 4
+
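+        # RNN branch: standardise the flat time series, then reshape to (samples, timestamps, features per date)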
+        x_train_rnn = np.load(data[0])
+        scaler_rnn = StandardScaler()
+        scaler_rnn.fit(x_train_rnn)
+        x_train_rnn = scaler_rnn.transform(x_train_rnn)
+        x_train_rnn = np.reshape(x_train_rnn, (x_train_rnn.shape[0], n_timestamps, -1))
+
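+        # CNN branch: flatten each patch to fit the scaler, then restore the (patch, patch, bands) layout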
+        x_train_cnn = np.load(data[1])
+        scaler_cnn = StandardScaler()
+        scaler_cnn.fit(x_train_cnn.reshape(-1,patch_size*patch_size*n_bands))
+        x_train_cnn = scaler_cnn.transform(x_train_cnn.reshape(-1,patch_size*patch_size*n_bands)).reshape(-1,patch_size,patch_size,n_bands)
+
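+        # Map raw class labels to contiguous indices, as required by sparse_categorical_crossentropy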
+        labels = np.load(data[2])
+        nb_class = len(np.unique(labels))
+        classes = np.array(sorted(np.unique(labels)))
+        clabels = np.array([np.where(classes==i)[0][0] for i in labels])
+
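+        # M3FusionModel returns three predictions (fused, RNN branch, CNN branch), hence the
+        # triple labels below and the down-weighted auxiliary losses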
+        model = TFmodels.M3FusionModel(nb_class,n_timestamps,ts_size,patch_size,n_bands)
+        opt = tf.keras.optimizers.Adam(learning_rate=0.0002)
+        model.compile(optimizer=opt,
+                      loss="sparse_categorical_crossentropy",
+                      loss_weights=[1,0.3,0.3],
+                      metrics=['accuracy']
+                      )
+
+        if len(data) == 3:
+            model.fit([x_train_rnn,x_train_cnn],[clabels,clabels,clabels], epochs=epochs, batch_size=128, callbacks=[log])
+        elif len(data) == 6:
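+            # Validation arrays are transformed with the scalers fitted on the training data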
+            x_valid_rnn = np.load(data[3])
+            x_valid_rnn = scaler_rnn.transform(x_valid_rnn)
+            x_valid_rnn = np.reshape(x_valid_rnn, (x_valid_rnn.shape[0], n_timestamps, -1))
+            x_valid_cnn = np.load(data[4])
+            x_valid_cnn = scaler_cnn.transform(x_valid_cnn.reshape(-1,patch_size*patch_size*n_bands)).reshape(-1,patch_size,patch_size,n_bands)
+            labels_val = np.load(data[5])
+            if nb_class != len(np.unique(labels_val)):
+                sys.exit("Missing classes in validation dataset : {}".format(data[5]))
+            clabels_val = np.array([np.where(classes==i)[0][0] for i in labels_val])
+            model.fit(x=[x_train_rnn,x_train_cnn],y=[clabels,clabels,clabels], epochs=epochs, batch_size=128, validation_data=([x_valid_rnn,x_valid_cnn],[clabels_val,clabels_val,clabels_val]), callbacks=[log])
+        # model.save(model_file, save_format="tf")
+        model.save_weights(model_file)
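+        # Keep both branch scalers so classification can reproduce the training-time preprocessing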
+        joblib.dump(scaler_rnn,scaler_file.replace(".joblib","_rnn.joblib"))
+        joblib.dump(scaler_cnn,scaler_file.replace(".joblib","_cnn.joblib"))
+        scaler_files.append(scaler_file.replace(".joblib","_rnn.joblib"))
+        scaler_files.append(scaler_file.replace(".joblib","_cnn.joblib"))
     with open(csv_classes,'w') as file :
         writer = csv.writer(file)
         for i,c in enumerate(classes):
             writer.writerow([i,c])
 
-    return model_file, scaler_file, csv_classes
+    return model_file, scaler_files, csv_classes
 
 def deepClassify(shp_list,code,scaler,model_file, csv_classes,out_fld,out_ext,feat,feat_mode = 'list',Nproc=1,compute_confidence=False):
     import geopandas
diff --git a/launchChain.py b/launchChain.py
index a026b833c12e69928b69b168ce2ec2c112bb0017..ed8a0cddb365afba392baf95f2e990da4d38824b 100644
--- a/launchChain.py
+++ b/launchChain.py
@@ -555,103 +555,106 @@ def main(argv):
         feat_groups_str = '::'.join([','.join(x) for x in feat_groups])
 
         # generating GT samples and merging tiles
-        if os.path.exists(samples_fld + '/GT_samples.shp'):
-            warnings.warn('File ' + samples_fld + '/GT_samples.shp' + ' already exists, skipping zonal stats on reference.')
-        else:
-            gttl = sorted(glob.glob(seg_fld + '/GT_segmentation_*.shp'))
-            ogttl = []
 
-            ext_flds = getFieldNames(gttl[0])
+        params = config.get('TRAINING CONFIGURATION', 'parameters').split(' ')
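+        # Zonal-stats samples are only needed by the classic classifiers and the standard
+        # tensorflow model; m3fusion builds its own arrays later (see the training step below)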
+        if '-classifier' in params or ('-tensorflow' in params and 'standard' in params):
+            if os.path.exists(samples_fld + '/GT_samples.shp'):
+                warnings.warn('File ' + samples_fld + '/GT_samples.shp' + ' already exists, skipping zonal stats on reference.')
+            else:
+                gttl = sorted(glob.glob(seg_fld + '/GT_segmentation_*.shp'))
+                ogttl = []
 
-            # BEGIN ** GT_SAMPLES **
-            cmd_list = []
+                ext_flds = getFieldNames(gttl[0])
 
-            for tl in gttl:
-                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-                for fn in tocp:
-                    shutil.copyfile(fn, samples_fld + '/' + os.path.basename(fn))
-                otl = samples_fld + '/' + os.path.basename(tl)
-                ogttl.append(otl)
-                cmd_list.append(
-                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+                # BEGIN ** GT_SAMPLES **
+                cmd_list = []
 
-            queuedProcess(cmd_list, N_processes=1, shell=sh)
-            # END ** GT_SAMPLES **
+                for tl in gttl:
+                    tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                    for fn in tocp:
+                        shutil.copyfile(fn, samples_fld + '/' + os.path.basename(fn))
+                    otl = samples_fld + '/' + os.path.basename(tl)
+                    ogttl.append(otl)
+                    cmd_list.append(
+                        ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                         '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
 
-            var_keys = [x for x in getFieldNames(ogttl[0]) if x not in ext_flds]
+                queuedProcess(cmd_list, N_processes=1, shell=sh)
+                # END ** GT_SAMPLES **
 
-            with open(samples_fld + '/field_names.csv', 'w') as varfile:
-                vw = csv.writer(varfile)
-                for x in zip(var_keys, feat_to_vars):
-                    vw.writerow(x)
+                var_keys = [x for x in getFieldNames(ogttl[0]) if x not in ext_flds]
 
-            cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
+                with open(samples_fld + '/field_names.csv', 'w') as varfile:
+                    vw = csv.writer(varfile)
+                    for x in zip(var_keys, feat_to_vars):
+                        vw.writerow(x)
 
-            cmd = cmd_base + [samples_fld + '/GT_samples.shp', ogttl[0]]
-            subprocess.call(cmd, shell=sh)
+                cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
 
-            for k in range(1, len(ogttl)):
-                cmd = cmd_base + ['-update', '-append', samples_fld + '/GT_samples.shp', ogttl[k], '-nln', 'GT_samples']
+                cmd = cmd_base + [samples_fld + '/GT_samples.shp', ogttl[0]]
                 subprocess.call(cmd, shell=sh)
 
-            for otl in ogttl:
-                shpd.DeleteDataSource(otl)
-                os.remove(otl.replace('.shp','.tif'))
+                for k in range(1, len(ogttl)):
+                    cmd = cmd_base + ['-update', '-append', samples_fld + '/GT_samples.shp', ogttl[k], '-nln', 'GT_samples']
+                    subprocess.call(cmd, shell=sh)
 
-        if ch_mode <= 0 or ch_mode == 2:
-            # generating VAL samples and merging tiles
-            gttl = sorted(glob.glob(seg_fld + '/VAL_segmentation_*.shp'))
-            ogttl = []
+                for otl in ogttl:
+                    shpd.DeleteDataSource(otl)
+                    os.remove(otl.replace('.shp','.tif'))
 
-            # BEGIN ** VAL_SAMPLES **
-            cmd_list = []
+            if ch_mode <= 0 or ch_mode == 2:
+                # generating VAL samples and merging tiles
+                gttl = sorted(glob.glob(seg_fld + '/VAL_segmentation_*.shp'))
+                ogttl = []
 
-            for tl in gttl:
-                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-                for fn in tocp:
-                    shutil.copyfile(fn, val_fld + '/' + os.path.basename(fn))
-                otl = val_fld + '/' + os.path.basename(tl)
-                ogttl.append(otl)
-                cmd_list.append(
-                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+                # BEGIN ** VAL_SAMPLES **
+                cmd_list = []
 
-            queuedProcess(cmd_list, N_processes=1, shell=sh)
-            # END ** VAL_SAMPLES **
+                for tl in gttl:
+                    tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                    for fn in tocp:
+                        shutil.copyfile(fn, val_fld + '/' + os.path.basename(fn))
+                    otl = val_fld + '/' + os.path.basename(tl)
+                    ogttl.append(otl)
+                    cmd_list.append(
+                        ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                         '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
 
-            cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
+                queuedProcess(cmd_list, N_processes=1, shell=sh)
+                # END ** VAL_SAMPLES **
 
-            cmd = cmd_base + [val_fld + '/VAL_samples.shp', ogttl[0]]
-            subprocess.call(cmd, shell=sh)
+                cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
 
-            for k in range(1, len(ogttl)):
-                cmd = cmd_base + ['-update', '-append', val_fld + '/VAL_samples.shp', ogttl[k], '-nln', 'VAL_samples']
+                cmd = cmd_base + [val_fld + '/VAL_samples.shp', ogttl[0]]
                 subprocess.call(cmd, shell=sh)
 
-            for otl in ogttl:
-                shpd.DeleteDataSource(otl)
-                os.remove(otl.replace('.shp', '.tif'))
-
-        if ch_mode > 0 or force_sample_generation is True:
-            # generating map samples (TBD)
-            gttl = sorted(glob.glob(seg_fld + '/segmentation_*.shp'))
-            ogttl = []
-
-            # BEGIN ** SAMPLES **
-            cmd_list = []
-
-            for tl in gttl:
-                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-                for fn in tocp:
-                    shutil.copyfile(fn, test_fld + '/' + os.path.basename(fn))
-                otl = test_fld + '/' + os.path.basename(tl)
-                ogttl.append(otl)
-                cmd_list.append(
-                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
-
-            queuedProcess(cmd_list, N_processes=N_proc, shell=sh)
+                for k in range(1, len(ogttl)):
+                    cmd = cmd_base + ['-update', '-append', val_fld + '/VAL_samples.shp', ogttl[k], '-nln', 'VAL_samples']
+                    subprocess.call(cmd, shell=sh)
+
+                for otl in ogttl:
+                    shpd.DeleteDataSource(otl)
+                    os.remove(otl.replace('.shp', '.tif'))
+
+            if ch_mode > 0 or force_sample_generation is True:
+                # generating map samples (TBD)
+                gttl = sorted(glob.glob(seg_fld + '/segmentation_*.shp'))
+                ogttl = []
+
+                # BEGIN ** SAMPLES **
+                cmd_list = []
+
+                for tl in gttl:
+                    tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                    for fn in tocp:
+                        shutil.copyfile(fn, test_fld + '/' + os.path.basename(fn))
+                    otl = test_fld + '/' + os.path.basename(tl)
+                    ogttl.append(otl)
+                    cmd_list.append(
+                        ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                         '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+
+                queuedProcess(cmd_list, N_processes=N_proc, shell=sh)
-            # END ** SAMPLES **
+                # END ** SAMPLES **
 
         if single_step:
@@ -722,6 +725,8 @@ def main(argv):
                             val_out_list = deepClassify(val_list, 'p' + cfield, scaler, mfn, csv_classes, mfld,
                                                         '_' + cfield, var_list)
                             os.rename(val_out_list[0], val_out)
+                        else:
+                            sys.exit('\n This classification method is not yet available for k-fold treatment')
                         txt_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '_report.txt')
                         classes, cm, acc, kappa, prf = surfaceValidation(test_folds[i], val_out, val_out_check, cfield)
                         formatValidationTxt(classes, cm, acc, kappa, prf, txt_out)
@@ -739,8 +744,53 @@ def main(argv):
                 if params[0] == "-classifier":
                     training(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list)
                 elif params[0] == '-tensorflow':
-                    deepTraining(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list)
-
+                    if params[1] == 'standard':
+                        deepTraining(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list)
+                    elif params[1].lower() == 'm3fusion':
+                        import Moringa2DL
+                        epochs=int(params[2])
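+                        # Build the (rnn, cnn, label) training arrays from the feature stack and the reference data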
+                        tf_data = Moringa2DL.generateAll(feat_fld+'/*',reference,seg_src,cfield,ofld=samples_fld)
+
+                        if ch_mode == 0 or ch_mode == 2:
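+                            # Append the validation arrays so deepTraining receives all six datasets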
+                            tf_valid_data = Moringa2DL.generateAll(feat_fld+'/*',validation,seg_src,cfield,ofld=val_fld,dstype='valid')
+                            tf_data += tf_valid_data
+                        deepTraining(tf_data,cfield,model_fld,params,var_list,epochs=epochs)
+                    else:
+                        sys.exit("Tensorflow parameter {} is not available".format(params[1]))
             if rfimp == True:
                 for cfield in cfieldlist:
                     getVariableImportance(samples_fld + '/GT_samples.shp',var_list,cfield,samples_fld + '/var_importance_' + cfield + '.csv', nbtrees=rfimpntrees, nodesize=rfimpnodesize, mtry=rfimpmaxfeat, nruns=rfimpnruns, field_names=samples_fld + '/field_names.csv')