From 192ef4ced8b84af0614bc92c9783ff4ffac79ad3 Mon Sep 17 00:00:00 2001 From: "raffaele.gaetano" <raffaele.gaetano@cirad.fr> Date: Wed, 19 Jul 2023 10:34:49 +0200 Subject: [PATCH] ENH: Repeat fold generation process for n retries if inconsistent. --- Learning/ObjectBased.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py index 5936243..ee96958 100644 --- a/Learning/ObjectBased.py +++ b/Learning/ObjectBased.py @@ -23,18 +23,22 @@ class ObjectBasedClassifier: self.training_base['folds'] = [] return - def gen_k_folds(self, k, class_field='class'): - sgk = StratifiedGroupKFold(n_splits=k, shuffle=True) - for tr_i, ts_i in sgk.split(self.training_base['X'], - self.training_base[class_field], - self.training_base['groups']): - self.training_base['folds'].append((tr_i, ts_i)) - # check if all classes are in all splits - n_classes = len(np.unique(self.training_base[class_field])) - ok = True - for f in self.training_base['folds']: - ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and - len(np.unique(self.training_base[class_field][f[1]])) == n_classes) + def gen_k_folds(self, k, class_field='class', n_retries=10): + ok = False + retry = 0 + while (not ok) and retry<n_retries: + self.training_base['folds'] = [] + sgk = StratifiedGroupKFold(n_splits=k, shuffle=True) + for tr_i, ts_i in sgk.split(self.training_base['X'], + self.training_base[class_field], + self.training_base['groups']): + self.training_base['folds'].append((tr_i, ts_i)) + # check if all classes are in all splits + n_classes = len(np.unique(self.training_base[class_field])) + ok = True + for f in self.training_base['folds']: + ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and + len(np.unique(self.training_base[class_field][f[1]])) == n_classes) if not ok: self.training_base['folds'] = [] raise Exception("Not all classes are present in each fold/split.\n" -- GitLab