From a1b2d34220f753a54cf5553bc84d181c5b1a558a Mon Sep 17 00:00:00 2001
From: Raffaele Gaetano <raffaele.gaetano@cirad.fr>
Date: Thu, 18 Jan 2024 15:28:13 +0100
Subject: [PATCH] ENH: optional automatic augmentation when there are not enough samples.

---
 Learning/ObjectBased.py      | 64 +++++++++++++++++++++++++-----------
 Learning/SampleManagement.py | 38 +++++++++++++++++++++
 Postprocessing/Report.py     | 15 +++++----
 Workflows/basic.py           |  2 +-
 Workflows/basic_config.json  |  3 +-
 5 files changed, 95 insertions(+), 27 deletions(-)
 create mode 100644 Learning/SampleManagement.py

diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py
index 06c8a73..8ef0e48 100644
--- a/Learning/ObjectBased.py
+++ b/Learning/ObjectBased.py
@@ -5,10 +5,14 @@ import numpy as np
 import pandas as pd
 
 from OBIA.OBIABase import *
-from sklearn.model_selection import StratifiedGroupKFold, GroupShuffleSplit
+from sklearn.model_selection import GroupShuffleSplit
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import confusion_matrix, accuracy_score, cohen_kappa_score, precision_recall_fscore_support
 
+from Learning.SampleManagement import gen_k_folds, generate_samples_from_set
+
+import warnings
+
 class ObjectBasedClassifier:
     def __init__(self, object_layer, time_series_list, user_feature_list,
                  reference_data=None, ref_class_field='class', ref_id_field='id'):
@@ -22,31 +26,50 @@ class ObjectBasedClassifier:
             self.obia_base.populate_ref_db()
             self.training_base = self.obia_base.get_reference_db_as_training_base(class_field=ref_class_field)
             self.training_base['folds'] = []
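+            # ids of synthetic samples created by augmentation (used to exclude them from validation)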
+            self.training_base['dummy_ids'] = []
         return
 
-    def gen_k_folds(self, k, class_field='class', n_retries=10):
+    def gen_k_folds(self, k, class_field='class', n_retries=10, augment=False):
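+        # with augment=True, classes lacking enough groups are augmented with synthetic samples instead of raising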
         ok = False
         retry = 0
         while (not ok) and retry<n_retries:
-            self.training_base['folds'] = []
-            sgk = StratifiedGroupKFold(n_splits=k, shuffle=True)
-            for tr_i, ts_i in sgk.split(self.training_base['X'],
-                                        self.training_base[class_field],
-                                        self.training_base['groups']):
-                self.training_base['folds'].append((tr_i, ts_i))
-            # check if all classes are in all splits
-            n_classes = len(np.unique(self.training_base[class_field]))
-            ok = True
-            for f in self.training_base['folds']:
-                ok &= (len(np.unique(self.training_base[class_field][f[0]])) == n_classes and
-                       len(np.unique(self.training_base[class_field][f[1]])) == n_classes)
+            self.training_base['folds'], ok, problematic = gen_k_folds(self.training_base['X'],
+                                                                       self.training_base[class_field],
+                                                                       self.training_base['groups'], k)
             retry += 1
         if not ok:
-            self.training_base['folds'] = []
-            raise Exception("Not all classes are present in each fold/split.\n"
-                            "Please check that you have enough groups (e.g. 2 x n_folds) per class.")
+            if not augment:
+                raise Exception("Not all classes are present in each fold/split.\n"
+                                "Please check that you have enough polygons (e.g. 2 x n_folds) per class.")
+            else:
+                warnings.warn('Classes {} do not have enough groups to ensure sample presence in each fold. Augmenting to 2 x n_folds samples per class.'.format(problematic))
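+                # target at least 2*k groups per problematic class (each synthetic sample forms its own group)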
+                n_samples_to_add = [2*k - len(np.unique(self.training_base['groups'][self.training_base[class_field]==c])) for c in problematic]
+                curr_grp = np.max(self.training_base['groups']) + 1
+                curr_id = np.max(self.training_base['obj_id']) + 1
+                for c,n in zip(problematic, n_samples_to_add):
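+                    # draw n synthetic feature vectors from a Gaussian fitted to this class's samples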
+                    x = self.training_base['X'][self.training_base[class_field]==c]
+                    s = generate_samples_from_set(x,n,0.01)
+                    sc = c * np.ones(n)
+                    sg = curr_grp + np.arange(n)
+                    sid = curr_id + np.arange(n)
+                    self.training_base['X'] = np.vstack([self.training_base['X'], s])
+                    self.training_base[class_field] = np.concatenate([self.training_base[class_field], sc])
+                    self.training_base['groups'] = np.concatenate([self.training_base['groups'], sg])
+                    self.training_base['obj_id'] = np.concatenate([self.training_base['obj_id'], sid])
+                    self.training_base['dummy_ids'] = np.concatenate([self.training_base['dummy_ids'], sid])
+                    curr_grp += n
+                    curr_id += n
+                retry = 0
+                while (not ok) and retry<n_retries:
+                    self.training_base['folds'], ok, problematic = gen_k_folds(self.training_base['X'],
+                                                                               self.training_base[class_field],
+                                                                               self.training_base['groups'], k)
+                    retry += 1
+                if not ok:
+                    raise Exception("Still not ok after augmentation. Please provide more samples.")
         return
 
+    # TODO: to be reworked!
     def gen_hold_out(self, test_train_ratio=0.5, n_splits=1, class_field='class'):
         gss = GroupShuffleSplit(n_splits=n_splits, test_size=test_train_ratio)
         for tr_i, ts_i in gss.split(self.training_base['X'],
@@ -73,13 +96,16 @@ class ObjectBasedClassifier:
             models.append(RandomForestClassifier(n_estimators=n_estimators))
             models[-1].fit(self.training_base['X'][tr_i], self.training_base[class_field][tr_i])
             l, c = self.training_base['obj_id'][ts_i], models[-1].predict(self.training_base['X'][ts_i])
+            # Remove dummy (augmented) ids and their corresponding class labels
+            # (beware: this can leave the test set empty!)
+            dummy_mask = np.isin(l, self.training_base['dummy_ids'])
+            c = c[~dummy_mask]
+            l = l[~dummy_mask]
             y_true, y_pred = self.obia_base.true_pred_bypixel(l, c, class_field)
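+            # evaluate on the full label set so per-fold metrics keep a consistent shape,
+            # even if a class is missing from a given fold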
             results.append(
                 {
-                    'conf_matrix': confusion_matrix(y_true, y_pred),
+                    'conf_matrix': confusion_matrix(y_true, y_pred, labels=np.unique(self.training_base[class_field])),
                     'accuracy': accuracy_score(y_true, y_pred),
                     'kappa' : cohen_kappa_score(y_true, y_pred),
-                    'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0),
+                    'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0, labels=np.unique(self.training_base[class_field])),
                     'importances' : models[-1].feature_importances_
                 }
             )
diff --git a/Learning/SampleManagement.py b/Learning/SampleManagement.py
new file mode 100644
index 0000000..12eed54
--- /dev/null
+++ b/Learning/SampleManagement.py
@@ -0,0 +1,38 @@
+import numpy as np
+
+from sklearn.model_selection import StratifiedGroupKFold
+
+def gen_k_folds(X, Y, G, k):
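+    # Build k stratified, group-aware folds and check that every class appears
+    # in both the train and test indices of each fold; returns (folds, ok, problematic),
+    # where 'problematic' lists the classes missing from at least one split.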
+    folds = []
+    sgk = StratifiedGroupKFold(n_splits=k, shuffle=True)
+    for tr_i, ts_i in sgk.split(X, Y, G):
+        folds.append((tr_i, ts_i))
+    # check if all classes are in all splits
+    problematic = []
+    for f in folds:
+        problematic.extend([
+            np.setdiff1d(
+                np.unique(Y),
+                np.unique(Y[f[0]])
+                ),
+            np.setdiff1d(
+                np.unique(Y),
+                np.unique(Y[f[1]])
+                )
+            ])
+    ok = all([x.size == 0 for x in problematic])
+    problematic = np.unique(np.concatenate(problematic))
+    if not ok:
+        folds = []
+    return folds, ok, problematic
+
+def generate_samples_from_set(X, num_samples=1, sigma_noise=0.0):
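+    # Fit a multivariate Gaussian (sample mean and covariance) to X and draw
+    # num_samples synthetic rows; optionally add isotropic noise of std sigma_noise.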
+    M = np.mean(X, axis=0)
+    C = np.cov(X, rowvar=False)
+    S = np.random.multivariate_normal(M, C, size=num_samples)
+    if sigma_noise > 0.0:
+        M = np.zeros(X.shape[1])
+        C = np.diag((sigma_noise**2) * np.ones(X.shape[1]))
+        N = np.random.multivariate_normal(M, C, size=num_samples)
+        S += N
+    return S
\ No newline at end of file
diff --git a/Postprocessing/Report.py b/Postprocessing/Report.py
index cd8509d..96b2219 100644
--- a/Postprocessing/Report.py
+++ b/Postprocessing/Report.py
@@ -43,7 +43,7 @@ def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name
         of = {}
         of['conf_matrices'] = []
         for i,r in enumerate(results):
-            cm = ConfusionMatrixDisplay.from_predictions(r['true_vs_pred'][0], r['true_vs_pred'][1],
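+            # an explicit label list keeps classes absent from a fold visible in the matrix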
+            cm = ConfusionMatrixDisplay.from_predictions(r['true_vs_pred'][0], r['true_vs_pred'][1], labels=labels,
                                                          normalize='true', include_values=True, values_format='.2f')
             of['conf_matrices'].append('{}/conf_matrix_{}.png'.format(out_dir, str(i).zfill(2)))
             cm.ax_.set_xticklabels(class_names, rotation=45, ha='right')
@@ -55,8 +55,8 @@ def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name
 
         of['cl_rep'] = []
         for r in results:
-            of['cl_rep'].append(classification_report(r['true_vs_pred'][0], r['true_vs_pred'][1],
-                                                      output_dict=True, target_names=class_names))
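+            # zero_division=0 yields 0 instead of a warning for classes with no samples in a fold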
+            of['cl_rep'].append(classification_report(r['true_vs_pred'][0], r['true_vs_pred'][1], labels=labels,
+                                                      output_dict=True, target_names=class_names, zero_division=0))
 
         fsc = [np.array([x[c]['f1-score'] for x in of['cl_rep']]) for c in class_names]
         fsc_m = [np.mean(x) for x in fsc]
@@ -182,9 +182,12 @@ def generate_pdf(of, out_pdf, name='output'):
                     row.cell('{:.4f}'.format(datum), align='R')
                 row.cell('{}'.format(data_row[-1]), align='R')
             row = table.row()
-            row.cell('Accuracy')
-            row.cell('{:.2f}%'.format(rep['accuracy']*100), align='R')
-
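+            # with an explicit label list, classification_report may report 'micro avg' in place of 'accuracy'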
+            if 'accuracy' in rep.keys():
+                row.cell('Accuracy')
+                row.cell('{:.2f}%'.format(rep['accuracy']*100), align='R')
+            elif 'micro avg' in rep.keys():
+                row.cell('Micro Avg. F1-Score')
+                row.cell('{:.2f}%'.format(rep['micro avg']['f1-score']*100), align='R')
 
     pdf.output(out_pdf)
 
diff --git a/Workflows/basic.py b/Workflows/basic.py
index 4f04044..81fd682 100644
--- a/Workflows/basic.py
+++ b/Workflows/basic.py
@@ -84,7 +84,7 @@ def train_valid_workflow(seg, ts_lst_pkl, d, m_file):
                                 reference_data=d['ref_db']['path'],
                                 ref_class_field=d['ref_db']['fields'])
 
-    obc.gen_k_folds(5, class_field=d['ref_db']['fields'][-1])
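+    # 'augment_if_missing' (config flag) enables synthetic sample generation for under-represented classes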
+    obc.gen_k_folds(5, class_field=d['ref_db']['fields'][-1], augment=d['training']['augment_if_missing'])
 
     if 'export_training_base' in d['training'].keys() and d['training']['export_training_base'] is True:
         obc.save_training_base('{}/_side/training_base.pkl'.format(os.path.join(d['output_path'], d['chain_name'])))
diff --git a/Workflows/basic_config.json b/Workflows/basic_config.json
index 147d8a6..54d8575 100644
--- a/Workflows/basic_config.json
+++ b/Workflows/basic_config.json
@@ -8,7 +8,8 @@
 	
 	"ref_db" : {
 		"path": "/path/to/ref/db/vector",
-		"fields": ["class_field_1", "class_field_2"]
+		"fields": ["class_field_1", "class_field_2"],
+		"augment_if_missing": false
 	},
 
 	"dem" : {
-- 
GitLab