From 705a3e4a5b8a94d5ca5cfa2317a5e55bd09dfcb7 Mon Sep 17 00:00:00 2001
From: Raffaele Gaetano <raffaele.gaetano@cirad.fr>
Date: Thu, 1 Jun 2023 10:22:23 +0200
Subject: [PATCH] ENH: ensure robust stratified group k-fold and hold-out split generation

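StratifiedGroupKFold with shuffling can produce splits in which some
classes are missing from the train or test side whenever a class is
carried by too few groups. The fold list is now initialised together
with the training base, every generated split is checked for full
class coverage, and an explicit error is raised when the check fails.
A GroupShuffleSplit-based hold-out alternative (gen_hold_out) is
added, the cross-validation summary in train_RF (accuracy, kappa,
precision, recall, F1) is re-enabled and returned alongside the
per-fold results, and a small run_test() driver exercises the new API.

A minimal usage sketch, assuming a classifier instantiated as in
run_test() below (the segmentation, reference, feature and THR inputs
are placeholder variables, not part of this patch):

    # placeholder inputs: segmentation raster, reference vector,
    # feature raster glob patterns, THR raster(s)
    obc = ObjectBasedClassifier(segmentation_tif, reference_shp,
                                feature_patterns, thr_rasters)
    obc.gen_k_folds(5)   # or: obc.gen_hold_out(test_train_ratio=0.5)
    models, summary, results = obc.train_RF(100)
    print(summary)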
---
 Learning/ObjectBased.py | 55 ++++++++++++++++++++++++++++++++++++-----
 1 file changed, 49 insertions(+), 6 deletions(-)

diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py
index 737b013..17c70e3 100644
--- a/Learning/ObjectBased.py
+++ b/Learning/ObjectBased.py
@@ -1,6 +1,9 @@
 import glob
+
+import numpy as np
+
 from OBIA.OBIABase import *
-from sklearn.model_selection import StratifiedGroupKFold
+from sklearn.model_selection import StratifiedGroupKFold, GroupShuffleSplit
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import confusion_matrix, accuracy_score, cohen_kappa_score, precision_recall_fscore_support
 
@@ -20,19 +23,46 @@ class ObjectBasedClassifier:
             'Y': Y,
             'groups': G,
             'perc2': p2,
-            'perc98': p98
+            'perc98': p98,
+            'folds': []
         }
         return
 
     def gen_k_folds(self, k):
-        self.training_base['folds'] = []
         sgk = StratifiedGroupKFold(n_splits=k, shuffle=True)
         for tr_i, ts_i in sgk.split(self.training_base['X'],
                                     self.training_base['Y'],
                                     self.training_base['groups']):
             self.training_base['folds'].append((tr_i, ts_i))
+        # check that every class appears in both the train and test indices of each fold
+        n_classes = len(np.unique(self.training_base['Y']))
+        ok = True
+        for f in self.training_base['folds']:
+            ok &= (len(np.unique(self.training_base['Y'][f[0]])) == n_classes and
+                   len(np.unique(self.training_base['Y'][f[1]])) == n_classes)
+        if not ok:
+            self.training_base['folds'] = []
+            raise Exception("Not all classes are present in each fold/split.\n"
+                            "Please check that you have enough groups (e.g. 2 x n_folds) per class.")
         return
 
+    def gen_hold_out(self, test_train_ratio=0.5, n_splits=1):
+        gss = GroupShuffleSplit(n_splits=n_splits, test_size=test_train_ratio)
+        for tr_i, ts_i in gss.split(self.training_base['X'],
+                                    self.training_base['Y'],
+                                    self.training_base['groups']):
+            self.training_base['folds'].append((tr_i, ts_i))
+        # check that every class appears in both the train and test indices of each split
+        n_classes = len(np.unique(self.training_base['Y']))
+        ok = True
+        for f in self.training_base['folds']:
+            ok &= (len(np.unique(self.training_base['Y'][f[0]])) == n_classes and
+                   len(np.unique(self.training_base['Y'][f[1]])) == n_classes)
+        if not ok:
+            self.training_base['folds'] = []
+            raise Exception("Not all classes are present in each split.\n"
+                            "Please check that you have enough groups (e.g. 2 x (1/test_train_ratio)) per class.")
+
     def train_RF(self, n_estimators):
         assert('folds' in self.training_base.keys())
         models = []
@@ -50,14 +80,27 @@ class ObjectBasedClassifier:
                     'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0)
                 }
             )
-            '''
             summary = {
                 'accuracy_mean': np.mean([x['accuracy'] for x in results]),
                 'accuracy_std': np.std([x['accuracy'] for x in results]),
                 'kappa_mean': np.mean([x['kappa'] for x in results]),
                 'kappa_std': np.std([x['kappa'] for x in results]),
+                'prec_mean': np.mean([x['p_r_f1'][0] for x in results], axis=0),
+                'prec_std': np.std([x['p_r_f1'][0] for x in results], axis=0),
+                'rec_mean': np.mean([x['p_r_f1'][1] for x in results], axis=0),
+                'rec_std': np.std([x['p_r_f1'][1] for x in results], axis=0),
                 'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0),
                 'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0)
             }
-            '''
-        return models, results
+        return models, summary, results
+
+# Test code: example usage on the Parakou sample data (hard-coded paths).
+def run_test():
+    obc = ObjectBasedClassifier('/DATA/Moringa_Sample/Parakou/output/segmentation/segmentation.tif',
+                                '/DATA/Moringa_Sample/Parakou/input/REF/ref.shp',
+                                ['/DATA/Moringa_Sample/Parakou/output/S2_processed/T31PDL/*/*FEAT.tif'],
+                                ['/DATA/Moringa_Sample/Parakou/input/THR/THR_SPOT6.tif'])
+    obc.gen_k_folds(5)
+    models, summary, results = obc.train_RF(100)
+    print(summary)
+    return
-- 
GitLab