Commit 06fd3f52 authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

FIX: Filter classes in palette file using actual true and pred classes

parent 2536dedc
No related merge requests found
Showing with 13 additions and 9 deletions
+13 -9
......@@ -93,7 +93,7 @@ class ObjectBasedClassifier:
assert('folds' in self.training_base.keys())
models = []
results = []
yt_yp = []
truelabs = np.array([])
for tr_i, ts_i in tqdm(self.training_base['folds'], desc='Training'):
models.append(RandomForestClassifier(n_estimators=n_estimators))
models[-1].fit(self.training_base['X'][tr_i], self.training_base[class_field][tr_i])
......@@ -102,6 +102,7 @@ class ObjectBasedClassifier:
c = np.delete(c, np.isin(l, self.training_base['dummy_ids']))
l = np.delete(l, np.isin(l, self.training_base['dummy_ids']))
y_true, y_pred = self.obia_base.true_pred_bypixel(l, c, class_field)
truelabs = np.unique(np.concatenate((truelabs,y_true,y_pred)))
results.append(
{
'conf_matrix': confusion_matrix(y_true, y_pred, labels=np.unique(self.training_base[class_field])),
......@@ -127,7 +128,8 @@ class ObjectBasedClassifier:
'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0),
'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0),
'importance_mean': {k:v for k, v in zip(self.obia_base.get_vars(), np.mean(all_imp, axis=0))},
'importance_std': {k:v for k, v in zip(self.obia_base.get_vars(), np.std(all_imp, axis=0))}
'importance_std': {k:v for k, v in zip(self.obia_base.get_vars(), np.std(all_imp, axis=0))},
'actual_labels': list(truelabs)
}
return models, summary, results
......
import os
def parse_colormap_file(fn):
def parse_colormap_file(fn, filter=None):
labels = []
colors = []
class_names = []
......@@ -8,6 +8,8 @@ def parse_colormap_file(fn):
with open(fn, 'r') as f:
for l in f.read().splitlines():
sl = l.split(' ')
if filter is not None and int(sl[0]) not in filter:
continue
labels.append(int(sl[0]))
colors.append((int(sl[1]),int(sl[2]),int(sl[3]),int(sl[4])))
class_names.append(' '.join(sl[5:]))
......
......@@ -30,12 +30,12 @@ def filter_and_order_importance(summary, importance_perc, max_num_var=35):
def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name=None,
importance_perc=0.75, max_variables=35):
labels, class_names, colors = parse_colormap_file(palette_fn)
labels, class_names, colors = parse_colormap_file(palette_fn, filter=summary['actual_labels'])
colors_norm = [(x[0]/255,x[1]/255,x[2]/255,x[3]/255) for x in colors]
with plt.ioff():
#font = {'weight': 'normal',
# 'size': 8}
#plt.rc('font', **font)
font = {'weight': 'normal',
'size': 8}
plt.rc('font', **font)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
if not isinstance(results, list):
......@@ -73,7 +73,7 @@ def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name
plt.tight_layout()
plt.savefig(of['summary'], dpi=300)
imp_m, imp_s, imp_n = filter_and_order_importance(summary, importance_perc)
imp_m, imp_s, imp_n = filter_and_order_importance(summary, importance_perc, max_num_var=max_variables)
fig, ax = plt.subplots()
ax.barh(range(len(imp_n)), imp_m, xerr=imp_s, align='center')
......@@ -194,7 +194,7 @@ def generate_pdf(of, out_pdf, name='output'):
def generate_text_report(results, summary, palette_fn, output_fn, name='output'):
if os.path.dirname(output_fn) != '':
os.makedirs(os.path.dirname(output_fn), exist_ok=True)
labels, class_names, _ = parse_colormap_file(palette_fn)
labels, class_names, _ = parse_colormap_file(palette_fn, filter=summary['actual_labels'])
lines = ['MORINGA Final Report for chain {}, {}'.format(name, datetime.now().strftime('%Y-%m-%d %Hh%M')), '']
table_lines = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment