Commit 54410b74 authored by Gaetano Raffaele
Browse files

WIP: Coding reporting tools

No related merge requests found
Showing with 56 additions and 17 deletions
+56 -17
......@@ -63,23 +63,30 @@ class ObjectBasedClassifier:
raise Exception("Not all classes are present in each split.\n"
"Please check that you have enough groups (e.g. 2 x (1/test_train_ratio)) per class.")
def train_RF(self, n_estimators):
def train_RF(self, n_estimators, return_true_vs_pred=False):
assert('folds' in self.training_base.keys())
models = []
results = []
for tr_i, ts_i in self.training_base['folds']:
yt_yp = []
for tr_i, ts_i in tqdm(self.training_base['folds'], desc='Training'):
models.append(RandomForestClassifier(n_estimators=n_estimators))
models[-1].fit(self.training_base['X'][tr_i], self.training_base['Y'][tr_i])
l, c = self.training_base['obj_id'][ts_i], models[-1].predict(self.training_base['X'][ts_i])
y_true, y_pred = self.obia_base.true_pred_bypixel(l, c)
results.append(
{
'conf_matrix': confusion_matrix(y_true, y_pred),
'accuracy': accuracy_score(y_true, y_pred),
'kappa' : cohen_kappa_score(y_true, y_pred),
'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0)
}
)
if return_true_vs_pred:
yt_yp.append((y_true, y_pred))
else:
results.append(
{
'conf_matrix': confusion_matrix(y_true, y_pred),
'accuracy': accuracy_score(y_true, y_pred),
'kappa' : cohen_kappa_score(y_true, y_pred),
'p_r_f1': precision_recall_fscore_support(y_true, y_pred, zero_division=0)
}
)
if return_true_vs_pred:
return models, yt_yp
else:
summary = {
'accuracy_mean': np.mean([x['accuracy'] for x in results]),
'accuracy_std': np.std([x['accuracy'] for x in results]),
......@@ -92,9 +99,10 @@ class ObjectBasedClassifier:
'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0),
'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0)
}
return models, summary, results
return models, summary, results
def classify(self, model, output_file=None, compress='NONE'):
prg = tqdm(desc='Classification', total=len(self.obia_base.tiles))
if isinstance(model, list):
for t, L, X in self.obia_base.tiled_data(normalize=[self.training_base['perc2'],
self.training_base['perc98']]):
......@@ -104,11 +112,13 @@ class ObjectBasedClassifier:
prob = np.prod(prob, axis=0)
c = model[0].classes_[np.argmax(prob, axis=1)]
self.obia_base.populate_map(t, L, c, output_file, compress)
prg.update(1)
else:
for t,L,X in self.obia_base.tiled_data(normalize=[self.training_base['perc2'],
self.training_base['perc98']]):
c = model.predict(X)
self.obia_base.populate_map(t, L, c, output_file, compress)
prg.update(1)
return
#TEST CODE
......@@ -124,9 +134,6 @@ def run_test():
glob.glob('/DATA/Benin/OBSYDYA_data/MORINGA/ext_features'))
'''
obc.gen_k_folds(5)
print('Performing Training and Cross-Validation...')
m,s,r = obc.train_RF(100)
print(s)
print('Performing Classification...')
obc.classify(m, '/DATA/Benin/OBSYDYA_data/MORINGAv2/firstmap.tif')
return obc
m,yt_yp = obc.train_RF(100, return_true_vs_pred=True)
obc.classify(m, '/DATA/Moringa_Sample/Parakou/output/classification/firstmap.tif')
return yt_yp
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
import numpy as np
import os
def generate_report_figures(yt_yp, out_dir, class_names=None):
    """Save per-fold confusion-matrix figures and collect per-class F1 scores.

    Args:
        yt_yp: A ``(y_true, y_pred)`` pair, or a list of such pairs (one per
            fold). Anything that is not a list is wrapped into a one-element
            list.
        out_dir: Directory receiving the figures; created if it is missing.
            Files are named ``conf_matrix_00.eps``, ``conf_matrix_01.eps``, ...
        class_names: Optional display names for the classes, forwarded to
            scikit-learn as ``display_labels`` / ``target_names``. When
            ``None``, the class keys found in the first classification report
            are used instead.

    Returns:
        A list with one NumPy array per class, each holding that class's
        F1 score for every fold, in fold order. Mean and standard deviation
        across folds can be taken directly on each array.
    """
    # Entries that classification_report(output_dict=True) adds next to the
    # per-class entries; they must not be mistaken for class names.
    aggregate_keys = ('accuracy', 'micro avg', 'macro avg', 'weighted avg', 'samples avg')

    if not isinstance(yt_yp, list):
        yt_yp = [yt_yp]

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    reports = []
    with plt.ioff():
        for i, r in enumerate(yt_yp):
            cm = ConfusionMatrixDisplay.from_predictions(
                r[0], r[1],
                display_labels=class_names,
                normalize='true',
                include_values=True,
                xticks_rotation=45,
            )
            fig_path = '{}/conf_matrix_{}.eps'.format(out_dir, str(i).zfill(2))
            cm.figure_.savefig(fig_path, bbox_inches='tight')
            # Release the figure: pyplot otherwise keeps one open per fold.
            plt.close(cm.figure_)

            reports.append(
                classification_report(r[0], r[1], output_dict=True, target_names=class_names)
            )

    if class_names is None:
        # Without explicit names the report is keyed by the stringified labels.
        if not reports:
            return []
        class_names = [k for k in reports[0] if k not in aggregate_keys]

    return [np.array([rep[c]['f1-score'] for rep in reports]) for c in class_names]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment