diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py
index 243e1be8cb9d04de76c3f9c2bb9af426eb616da4..7a0d4eb474a6a772ce5425f78c0d7d380ffc9705 100644
--- a/Learning/ObjectBased.py
+++ b/Learning/ObjectBased.py
@@ -1,6 +1,7 @@
 import glob
 
 import numpy as np
+import pandas as pd
 
 from OBIA.OBIABase import *
 from sklearn.model_selection import StratifiedGroupKFold, GroupShuffleSplit
@@ -16,63 +17,55 @@ class ObjectBasedClassifier:
         for ras in user_feature_list:
             self.obia_base.add_raster_for_stats(ras)
         self.obia_base.populate_ref_db()
-        L, X, Y, G, p2, p98 = self.obia_base.get_reference_db_as_training_base()
-        self.training_base = {
-            'obj_id': L,
-            'X': X,
-            'Y': Y,
-            'groups': G,
-            'perc2': p2,
-            'perc98': p98,
-            'folds': []
-        }
+        self.training_base = self.obia_base.get_reference_db_as_training_base(class_field=ref_class_field)  # dict with obj_id, X, groups, perc2/perc98 and one label array per class field
+        self.training_base['folds'] = []
         return
 
-    def gen_k_folds(self, k):
+    def gen_k_folds(self, k, class_field='class'):
         sgk = StratifiedGroupKFold(n_splits=k, shuffle=True)
         for tr_i, ts_i in sgk.split(self.training_base['X'],
-                                    self.training_base['Y'],
+                                    self.training_base[class_field],
                                     self.training_base['groups']):
             self.training_base['folds'].append((tr_i, ts_i))
         # check if all classes are in all splits
-        n_classes = len(np.unique(self.training_base['Y']))
+        n_classes = len(np.unique(self.training_base[class_field]))
         ok = True
         for f in self.training_base['folds']:
-            ok &= (len(np.unique(self.training_base['Y'][f[0]])) == n_classes and
-                   len(np.unique(self.training_base['Y'][f[1]])) == n_classes)
+            ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and
+                   len(np.unique(self.training_base[class_field][f[1]])) == n_classes)
         if not ok:
             self.training_base['folds'] = []
             raise Exception("Not all classes are present in each fold/split.\n"
                             "Please check that you have enough groups (e.g. 2 x n_folds) per class.")
         return
 
-    def gen_hold_out(self, test_train_ratio=0.5, n_splits=1):
+    def gen_hold_out(self, test_train_ratio=0.5, n_splits=1, class_field='class'):
         gss = GroupShuffleSplit(n_splits=n_splits, test_size=test_train_ratio)
         for tr_i, ts_i in gss.split(self.training_base['X'],
-                                    self.training_base['Y'],
+                                    self.training_base[class_field],
                                     self.training_base['groups']):
             self.training_base['folds'].append((tr_i, ts_i))
         # check if all classes are in all splits
-        n_classes = len(np.unique(self.training_base['Y']))
+        n_classes = len(np.unique(self.training_base[class_field]))
         ok = True
         for f in self.training_base['folds']:
-            ok &= (len(np.unique(self.training_base['Y'][f[0]])) == n_classes and
-                   len(np.unique(self.training_base['Y'][f[1]])) == n_classes)
+            ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and
+                   len(np.unique(self.training_base[class_field][f[1]])) == n_classes)
         if not ok:
             self.training_base['folds'] = []
             raise Exception("Not all classes are present in each split.\n"
                             "Please check that you have enough groups (e.g. 2 x (1/test_train_ratio)) per class.")
 
-    def train_RF(self, n_estimators, return_true_vs_pred=False):
+    def train_RF(self, n_estimators, class_field='class', return_true_vs_pred=False):
         assert('folds' in self.training_base.keys())
         models = []
         results = []
         yt_yp = []
         for tr_i, ts_i in tqdm(self.training_base['folds'], desc='Training'):
             models.append(RandomForestClassifier(n_estimators=n_estimators))
-            models[-1].fit(self.training_base['X'][tr_i], self.training_base['Y'][tr_i])
+            models[-1].fit(self.training_base['X'][tr_i], self.training_base[class_field][tr_i])
             l, c = self.training_base['obj_id'][ts_i], models[-1].predict(self.training_base['X'][ts_i])
-            y_true, y_pred = self.obia_base.true_pred_bypixel(l, c)
+            y_pred, y_true = self.obia_base.true_pred_bypixel(l, c, class_field)  # true_pred_bypixel returns (pred, true)
             if return_true_vs_pred:
                 yt_yp.append((y_true, y_pred))
             else:
@@ -81,12 +74,14 @@ class ObjectBasedClassifier:
                         'conf_matrix': confusion_matrix(y_true, y_pred),
                         'accuracy': accuracy_score(y_true, y_pred),
                         'kappa' : cohen_kappa_score(y_true, y_pred),
-                        'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0)
+                        'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0),
+                        'importances' : models[-1].feature_importances_
                     }
                 )
         if return_true_vs_pred:
             return models, yt_yp
         else:
+            all_imp = np.vstack([x['importances'] for x in results])  # per-fold feature importances, shape (n_folds, n_features)
             summary = {
                 'accuracy_mean': np.mean([x['accuracy'] for x in results]),
                 'accuracy_std': np.std([x['accuracy'] for x in results]),
@@ -97,7 +92,9 @@ class ObjectBasedClassifier:
                 'rec_mean': np.mean([x['p_r_f1'][1] for x in results], axis=0),
                 'rec_std': np.std([x['p_r_f1'][1] for x in results], axis=0),
                 'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0),
-                'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0)
+                'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0),
+                'importance_mean': dict(zip(self.obia_base.get_vars(), np.mean(all_imp, axis=0))),
+                'importance_std': dict(zip(self.obia_base.get_vars(), np.std(all_imp, axis=0)))
             }
             return models, summary, results
 
@@ -124,16 +121,21 @@ class ObjectBasedClassifier:
 #TEST CODE
 def run_test():
     obc = ObjectBasedClassifier('/DATA/Moringa_Sample/Parakou/output/segmentation/segmentation.tif',
-                                '/DATA/Moringa_Sample/Parakou/input/REF/ref.shp',
+                                '/DATA/Moringa_Sample/Parakou/input/REF/ref_l2.shp',
                                 ['/DATA/Moringa_Sample/Parakou/output/S2_processed/T31PDL/*/*FEAT.tif'],
-                                ['/DATA/Moringa_Sample/Parakou/input/THR/THR_SPOT6.tif'])
+                                ['/DATA/Moringa_Sample/Parakou/input/THR/THR_SPOT6.tif'],
+                                ref_class_field=['class', 'Class_L1a'])
     '''
     obc = ObjectBasedClassifier('/DATA/Benin/OBSYDYA_data/MORINGA/SEGMENTATION/segmentation.tif',
                                 '/DATA/Benin/OBSYDYA_data/MORINGA/reference/BD_OBSYDYA_2022_ParakouNdali_v0.2.shp',
                                 ['/DATA/Benin/OBSYDYA_data/MORINGA/basefolder/FEAT/S2_THEIA_FEAT/S2_THEIA_MOSAIC_*.tif'],
                                 glob.glob('/DATA/Benin/OBSYDYA_data/MORINGA/ext_features'))
     '''
-    obc.gen_k_folds(5)
-    m,yt_yp = obc.train_RF(100, return_true_vs_pred=True)
-    obc.classify(m, '/DATA/Moringa_Sample/Parakou/output/classification/firstmap.tif')
-    return yt_yp
+    obc.gen_k_folds(5, class_field='Class_L1a')  # the resulting folds are reused by both train_RF calls below
+    #obc.gen_hold_out(0.2, class_field='Class_L1a')
+    #m,yt_yp = obc.train_RF(100, return_true_vs_pred=True)
+    m1, s1, r1 = obc.train_RF(100, class_field='class')
+    m2, s2, r2 = obc.train_RF(100, class_field='Class_L1a')
+    obc.classify(m1, '/DATA/Moringa_Sample/Parakou/output/classification/firstmap_l1.tif')
+    obc.classify(m2, '/DATA/Moringa_Sample/Parakou/output/classification/firstmap_l2.tif')
+    return m1,s1,r1,m2,s2,r2
diff --git a/OBIA/OBIABase.py b/OBIA/OBIABase.py
index e45f046f081a6b9cb820ea28077b138591ce055a..48d8889d6514d4e2988c69d91741cb1088d13d79 100644
--- a/OBIA/OBIABase.py
+++ b/OBIA/OBIABase.py
@@ -57,6 +57,8 @@ class OBIABase:
         self.output_map = None
 
     def init_ref_db(self, vector_file, id_field, class_field):
+        if isinstance(class_field, str):
+            class_field = [class_field]  # accept a single field name or a list of field names
         ras_id = otb.Registry.CreateApplication('Rasterization')
         ras_id.SetParameterString('in', vector_file)
         ras_id.SetParameterString('im', self.object_layer)
@@ -65,21 +67,21 @@ class OBIABase:
         ras_id.Execute()
         #ids = ras_id.GetImageAsNumpyArray('out')
 
-        ras_cl = otb.Registry.CreateApplication('Rasterization')
-        ras_cl.SetParameterString('in', vector_file)
-        ras_cl.SetParameterString('im', self.object_layer)
-        ras_cl.SetParameterString('mode', 'attribute')
-        ras_cl.SetParameterString('mode.attribute.field', class_field)
-        ras_cl.Execute()
-        #cls = ras_cl.GetImageAsNumpyArray('out')
+        ras_cl = []
+        for cf in class_field:
+            ras_cl.append(otb.Registry.CreateApplication('Rasterization'))
+            ras_cl[-1].SetParameterString('in', vector_file)
+            ras_cl[-1].SetParameterString('im', self.object_layer)
+            ras_cl[-1].SetParameterString('mode', 'attribute')
+            ras_cl[-1].SetParameterString('mode.attribute.field', cf)
+            ras_cl[-1].Execute()
 
         in_seg = to_otb_pipeline(self.object_layer)
-        #obj = in_seg.GetImageAsNumpyArray('out')
 
         intensity_img = otb.Registry.CreateApplication('ConcatenateImages')
         intensity_img.AddImageToParameterInputImageList('il', in_seg.GetParameterOutputImage('out'))
         intensity_img.AddImageToParameterInputImageList('il', ras_id.GetParameterOutputImage('out'))
-        intensity_img.AddImageToParameterInputImageList('il', ras_cl.GetParameterOutputImage('out'))
+        for rcl in ras_cl: intensity_img.AddImageToParameterInputImageList('il', rcl.GetParameterOutputImage('out'))
         intensity_img.Execute()
 
         ref_ol = otb.Registry.CreateApplication('BandMath')
@@ -90,7 +92,7 @@ class OBIABase:
         self.ref_obj_layer_pipe = [in_seg, ras_id, ref_ol]
 
         self.ref_db = pd.DataFrame(data=[],
-                                   columns=['area', 'orig_label', 'polygon_id', 'class'],
+                                   columns=['area', 'orig_label', 'polygon_id'] + class_field,
                                    index=[])
         r = otb.itkRegion()
         for tn, t in tqdm(self.tiles.items(), desc='Init. Ref. DB', total=len(self.tiles)):
@@ -276,28 +278,35 @@ class OBIABase:
         vds.to_file(out_vector)
         return
 
-    def get_reference_db_as_training_base(self):
+    def get_vars(self):
+        return [item for sublist in self.raster_var_names for item in sublist]  # flattened feature names, in X column order
+
+    def get_reference_db_as_training_base(self, class_field='class'):
+        if isinstance(class_field, str):
+            class_field = [class_field]
         assert(self.ref_db is not None and len(self.raster_var_names)>0)
-        vars = [item for sublist in self.raster_var_names for item in sublist]
-        L = self.ref_db['orig_label'].to_numpy(dtype=int)
-        X = self.ref_db[vars].to_numpy()
+        vars = self.get_vars()
+        out = {}
+        out['obj_id'] = self.ref_db['orig_label'].to_numpy(dtype=int)
+        out['X'] = self.ref_db[vars].to_numpy()
         # compute percentiles and normalize
-        p2 = np.zeros(X.shape[1])
-        p98 = np.zeros(X.shape[1])
+        out['perc2'] = np.zeros(out['X'].shape[1])
+        out['perc98'] = np.zeros(out['X'].shape[1])
         for g in self.raster_groups:
-            tmp = X[:,g]
+            tmp = out['X'][:,g]
             m,M = np.percentile(tmp, [2, 98])
             if isinstance(g, list):
                 for x in g:
-                    p2[x] = m
-                    p98[x] = M
+                    out['perc2'][x] = m
+                    out['perc98'][x] = M
             else:
-                p2[g] = m
-                p98[g] = M
-            X[:,g] = (tmp - m)/(M - m)
-        Y = self.ref_db['class'].to_numpy(dtype=int)
-        G = self.ref_db['polygon_id'].to_numpy(dtype=int)
-        return L,X,Y,G,p2,p98
+                out['perc2'][g] = m
+                out['perc98'][g] = M
+            out['X'][:,g] = (tmp - m)/(M - m)
+        for cf in class_field:
+            out[cf] = self.ref_db[cf].to_numpy(dtype=int)  # one label vector per class field, keyed by field name
+        out['groups'] = self.ref_db['polygon_id'].to_numpy(dtype=int)
+        return out
 
     def tiled_data(self, normalize=None):
         vars = [item for sublist in self.raster_var_names for item in sublist]
@@ -355,7 +364,7 @@ class OBIABase:
         for tn in tqdm(self.tiles.keys(), desc="Writing output map"):
             self.populate_map(tn,obj_id,classes,output_file,compress)
 
-    def true_pred_bypixel(self, labels, predicted_classes):
+    def true_pred_bypixel(self, labels, predicted_classes, class_field='class'):
         pred_c = np.zeros(np.max(self.ref_db['orig_label']).astype(int)+1)
         pred_c[labels] = predicted_classes
         support = []
@@ -365,7 +374,7 @@ class OBIABase:
         pred = pred_c[support]
         true_c = np.zeros(np.max(self.ref_db['orig_label']).astype(int)+1)
         # ATTENTION: works if "labels" is sorted (as provided by get_reference_...)
-        true_c[labels] = self.ref_db.loc[self.ref_db['orig_label'].isin(labels),'class'].to_numpy(dtype=int)
+        true_c[labels] = self.ref_db.loc[self.ref_db['orig_label'].isin(labels),class_field].to_numpy(dtype=int)
         true = true_c[support]
         return pred[pred>0], true[pred>0]