Commit 192ef4ce authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

ENH: Repeat fold generation process for n retries if inconsistent.

No related merge requests found
Showing with 16 additions and 12 deletions
+16 -12
......@@ -23,18 +23,22 @@ class ObjectBasedClassifier:
self.training_base['folds'] = []
return
def gen_k_folds(self, k, class_field='class'):
sgk = StratifiedGroupKFold(n_splits=k, shuffle=True)
for tr_i, ts_i in sgk.split(self.training_base['X'],
self.training_base[class_field],
self.training_base['groups']):
self.training_base['folds'].append((tr_i, ts_i))
# check if all classes are in all splits
n_classes = len(np.unique(self.training_base[class_field]))
ok = True
for f in self.training_base['folds']:
ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and
len(np.unique(self.training_base[class_field][f[1]])) == n_classes)
def gen_k_folds(self, k, class_field='class', n_retries=10):
ok = False
retry = 0
while (not ok) and retry<n_retries:
self.training_base['folds'] = []
sgk = StratifiedGroupKFold(n_splits=k, shuffle=True)
for tr_i, ts_i in sgk.split(self.training_base['X'],
self.training_base[class_field],
self.training_base['groups']):
self.training_base['folds'].append((tr_i, ts_i))
# check if all classes are in all splits
n_classes = len(np.unique(self.training_base[class_field]))
ok = True
for f in self.training_base['folds']:
ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and
len(np.unique(self.training_base[class_field][f[1]])) == n_classes)
if not ok:
self.training_base['folds'] = []
raise Exception("Not all classes are present in each fold/split.\n"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment