diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py index 5936243e5eb6a5fddd4a21b9d7f38d24909e1bf8..ee96958c9244630fd2309536bb1ed841769d810b 100644 --- a/Learning/ObjectBased.py +++ b/Learning/ObjectBased.py @@ -23,18 +23,22 @@ class ObjectBasedClassifier: self.training_base['folds'] = [] return - def gen_k_folds(self, k, class_field='class'): - sgk = StratifiedGroupKFold(n_splits=k, shuffle=True) - for tr_i, ts_i in sgk.split(self.training_base['X'], - self.training_base[class_field], - self.training_base['groups']): - self.training_base['folds'].append((tr_i, ts_i)) - # check if all classes are in all splits - n_classes = len(np.unique(self.training_base[class_field])) - ok = True - for f in self.training_base['folds']: - ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and - len(np.unique(self.training_base[class_field][f[1]])) == n_classes) + def gen_k_folds(self, k, class_field='class', n_retries=10): + ok = False + retry = 0 + while (not ok) and retry<n_retries: + self.training_base['folds'] = [] + sgk = StratifiedGroupKFold(n_splits=k, shuffle=True) + for tr_i, ts_i in sgk.split(self.training_base['X'], + self.training_base[class_field], + self.training_base['groups']): + self.training_base['folds'].append((tr_i, ts_i)) + # check if all classes are in all splits + n_classes = len(np.unique(self.training_base[class_field])) + ok = True + for f in self.training_base['folds']: + ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and + len(np.unique(self.training_base[class_field][f[1]])) == n_classes) if not ok: self.training_base['folds'] = [] raise Exception("Not all classes are present in each fold/split.\n"