diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py
index 5936243e5eb6a5fddd4a21b9d7f38d24909e1bf8..ee96958c9244630fd2309536bb1ed841769d810b 100644
--- a/Learning/ObjectBased.py
+++ b/Learning/ObjectBased.py
@@ -23,18 +23,22 @@ class ObjectBasedClassifier:
             self.training_base['folds'] = []
         return
 
-    def gen_k_folds(self, k, class_field='class'):
-        sgk = StratifiedGroupKFold(n_splits=k, shuffle=True)
-        for tr_i, ts_i in sgk.split(self.training_base['X'],
-                                    self.training_base[class_field],
-                                    self.training_base['groups']):
-            self.training_base['folds'].append((tr_i, ts_i))
-        # check if all classes are in all splits
-        n_classes = len(np.unique(self.training_base[class_field]))
-        ok = True
-        for f in self.training_base['folds']:
-            ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and
-                   len(np.unique(self.training_base[class_field][f[1]])) == n_classes)
+    def gen_k_folds(self, k, class_field='class', n_retries=10):
+        ok = False
+        retry = 0
+        while (not ok) and retry<n_retries:
+            self.training_base['folds'] = []
+            sgk = StratifiedGroupKFold(n_splits=k, shuffle=True)
+            for tr_i, ts_i in sgk.split(self.training_base['X'],
+                                        self.training_base[class_field],
+                                        self.training_base['groups']):
+                self.training_base['folds'].append((tr_i, ts_i))
+            # check if all classes are in all splits
+            n_classes = len(np.unique(self.training_base[class_field]))
+            ok = True
+            for f in self.training_base['folds']:
+                ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and
+                       len(np.unique(self.training_base[class_field][f[1]])) == n_classes)
         if not ok:
             self.training_base['folds'] = []
             raise Exception("Not all classes are present in each fold/split.\n"