Commit 705a3e4a authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

ENH: ensuring robust stratified group k-fold generation.

parent dc43c50f
No related merge requests found
Showing with 49 additions and 6 deletions
+49 -6
import glob
import numpy as np
from OBIA.OBIABase import *
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.model_selection import StratifiedGroupKFold, GroupShuffleSplit
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, accuracy_score, cohen_kappa_score, precision_recall_fscore_support
......@@ -20,19 +23,46 @@ class ObjectBasedClassifier:
'Y': Y,
'groups': G,
'perc2': p2,
'perc98': p98
'perc98': p98,
'folds': []
}
return
def gen_k_folds(self, k):
    """Generate k stratified, group-aware folds and store them.

    The folds are written to ``self.training_base['folds']`` as a list of
    ``(train_indices, test_indices)`` tuples produced by a shuffled
    ``StratifiedGroupKFold`` over the stored 'X', 'Y' and 'groups' arrays.
    Any previously stored folds are discarded first.

    Raises:
        Exception: if the train or test part of any fold does not contain
            every class found in ``self.training_base['Y']``. The stored
            folds are emptied before raising.
    """
    base = self.training_base
    base['folds'] = []
    splitter = StratifiedGroupKFold(n_splits=k, shuffle=True)
    base['folds'].extend(
        (train_idx, test_idx)
        for train_idx, test_idx in splitter.split(base['X'], base['Y'], base['groups'])
    )

    # Every class must show up on both sides of every fold, otherwise
    # per-class metrics computed on the folds are not comparable.
    labels = base['Y']
    expected = len(np.unique(labels))
    complete = all(
        len(np.unique(labels[train_idx])) == expected
        and len(np.unique(labels[test_idx])) == expected
        for train_idx, test_idx in base['folds']
    )
    if not complete:
        base['folds'] = []
        raise Exception("Not all classes are present in each fold/split.\n"
                        "Please check that you have enough groups (e.g. 2 x n_folds) per class.")
    return
def gen_hold_out(self, test_train_ratio=0.5, n_splits=1):
    """Generate group-aware hold-out splits and store them.

    The splits are written to ``self.training_base['folds']`` as a list of
    ``(train_indices, test_indices)`` tuples produced by
    ``GroupShuffleSplit`` over the stored 'X', 'Y' and 'groups' arrays.

    Args:
        test_train_ratio: passed to ``GroupShuffleSplit`` as ``test_size``.
        n_splits: number of independent hold-out splits to draw.

    Raises:
        Exception: if the train or test part of any split does not contain
            every class found in ``self.training_base['Y']``. The stored
            folds are emptied before raising.
    """
    base = self.training_base
    # Start from a clean list, as gen_k_folds does: without this, splits
    # left over from an earlier call would be mixed with the new ones and
    # would also take part in the class-coverage check below.
    base['folds'] = []
    gss = GroupShuffleSplit(n_splits=n_splits, test_size=test_train_ratio)
    for tr_i, ts_i in gss.split(base['X'], base['Y'], base['groups']):
        base['folds'].append((tr_i, ts_i))

    # check if all classes are in all splits
    labels = base['Y']
    n_classes = len(np.unique(labels))
    ok = all(
        len(np.unique(labels[tr_i])) == n_classes
        and len(np.unique(labels[ts_i])) == n_classes
        for tr_i, ts_i in base['folds']
    )
    if not ok:
        base['folds'] = []
        raise Exception("Not all classes are present in each split.\n"
                        "Please check that you have enough groups (e.g. 2 x (1/test_train_ratio)) per class.")
def train_RF(self, n_estimators):
assert('folds' in self.training_base.keys())
models = []
......@@ -50,14 +80,27 @@ class ObjectBasedClassifier:
'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0)
}
)
'''
summary = {
'accuracy_mean': np.mean([x['accuracy'] for x in results]),
'accuracy_std': np.std([x['accuracy'] for x in results]),
'kappa_mean': np.mean([x['kappa'] for x in results]),
'kappa_std': np.std([x['kappa'] for x in results]),
'prec_mean': np.mean([x['p_r_f1'][0] for x in results], axis=0),
'prec_std': np.std([x['p_r_f1'][0] for x in results], axis=0),
'rec_mean': np.mean([x['p_r_f1'][1] for x in results], axis=0),
'rec_std': np.std([x['p_r_f1'][1] for x in results], axis=0),
'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0),
'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0)
}
'''
return models, results
return models, summary, results
#TEST CODE
def run_test():
    """Run a manual smoke test with hard-coded sample paths.

    Builds an ObjectBasedClassifier from four positional arguments,
    generates 5 folds, calls ``train_RF(100)`` and prints the second of the
    three values it returns.
    """
    # The local names below are only descriptive guesses based on the paths;
    # the arguments are passed positionally, in the original order.
    segmentation_path = '/DATA/Moringa_Sample/Parakou/output/segmentation/segmentation.tif'
    reference_path = '/DATA/Moringa_Sample/Parakou/input/REF/ref.shp'
    feature_patterns = ['/DATA/Moringa_Sample/Parakou/output/S2_processed/T31PDL/*/*FEAT.tif']
    extra_rasters = ['/DATA/Moringa_Sample/Parakou/input/THR/THR_SPOT6.tif']

    classifier = ObjectBasedClassifier(segmentation_path,
                                       reference_path,
                                       feature_patterns,
                                       extra_rasters)
    classifier.gen_k_folds(5)
    _models, summary, _results = classifier.train_RF(100)
    print(summary)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment