From 705a3e4a5b8a94d5ca5cfa2317a5e55bd09dfcb7 Mon Sep 17 00:00:00 2001 From: Raffaele Gaetano <raffaele.gaetano@cirad.fr> Date: Thu, 1 Jun 2023 10:22:23 +0200 Subject: [PATCH] ENH: ensuring robust stratified group k-fold generation. --- Learning/ObjectBased.py | 55 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py index 737b013..17c70e3 100644 --- a/Learning/ObjectBased.py +++ b/Learning/ObjectBased.py @@ -1,6 +1,9 @@ import glob + +import numpy as np + from OBIA.OBIABase import * -from sklearn.model_selection import StratifiedGroupKFold +from sklearn.model_selection import StratifiedGroupKFold, GroupShuffleSplit from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import confusion_matrix, accuracy_score, cohen_kappa_score, precision_recall_fscore_support @@ -20,19 +23,46 @@ class ObjectBasedClassifier: 'Y': Y, 'groups': G, 'perc2': p2, - 'perc98': p98 + 'perc98': p98, + 'folds': [] } return def gen_k_folds(self, k): - self.training_base['folds'] = [] sgk = StratifiedGroupKFold(n_splits=k, shuffle=True) for tr_i, ts_i in sgk.split(self.training_base['X'], self.training_base['Y'], self.training_base['groups']): self.training_base['folds'].append((tr_i, ts_i)) + # check if all classes are in all splits + n_classes = len(np.unique(self.training_base['Y'])) + ok = True + for f in self.training_base['folds']: + ok &= (len(np.unique(self.training_base['Y'][f[0]])) == n_classes and + len(np.unique(self.training_base['Y'][f[1]])) == n_classes) + if not ok: + self.training_base['folds'] = [] + raise Exception("Not all classes are present in each fold/split.\n" + "Please check that you have enough groups (e.g. 2 x n_folds) per class.") return + def gen_hold_out(self, test_train_ratio=0.5, n_splits=1): + gss = GroupShuffleSplit(n_splits=n_splits, test_size=test_train_ratio) + for tr_i, ts_i in gss.split(self.training_base['X'], + self.training_base['Y'], + self.training_base['groups']): + self.training_base['folds'].append((tr_i, ts_i)) + # check if all classes are in all splits + n_classes = len(np.unique(self.training_base['Y'])) + ok = True + for f in self.training_base['folds']: + ok &= (len(np.unique(self.training_base['Y'][f[0]])) == n_classes and + len(np.unique(self.training_base['Y'][f[1]])) == n_classes) + if not ok: + self.training_base['folds'] = [] + raise Exception("Not all classes are present in each split.\n" + "Please check that you have enough groups (e.g. 2 x (1/test_train_ratio)) per class.") + def train_RF(self, n_estimators): assert('folds' in self.training_base.keys()) models = [] @@ -50,14 +80,27 @@ class ObjectBasedClassifier: 'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0) } ) - ''' summary = { 'accuracy_mean': np.mean([x['accuracy'] for x in results]), 'accuracy_std': np.std([x['accuracy'] for x in results]), 'kappa_mean': np.mean([x['kappa'] for x in results]), 'kappa_std': np.std([x['kappa'] for x in results]), + 'prec_mean': np.mean([x['p_r_f1'][0] for x in results], axis=0), + 'prec_std': np.std([x['p_r_f1'][0] for x in results], axis=0), + 'rec_mean': np.mean([x['p_r_f1'][1] for x in results], axis=0), + 'rec_std': np.std([x['p_r_f1'][1] for x in results], axis=0), 'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0), 'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0) } - ''' - return models, results + return models, summary, results + +#TEST CODE +def run_test(): + obc = ObjectBasedClassifier('/DATA/Moringa_Sample/Parakou/output/segmentation/segmentation.tif', + '/DATA/Moringa_Sample/Parakou/input/REF/ref.shp', + ['/DATA/Moringa_Sample/Parakou/output/S2_processed/T31PDL/*/*FEAT.tif'], + ['/DATA/Moringa_Sample/Parakou/input/THR/THR_SPOT6.tif']) + obc.gen_k_folds(5) + m,s,r = obc.train_RF(100) + print(s) + return -- GitLab