Commit 06fd3f52 authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

FIX: Filter classes in palette file using actual true and pred classes

parent 2536dedc
No related merge requests found
Showing with 13 additions and 9 deletions
+13 -9
...@@ -93,7 +93,7 @@ class ObjectBasedClassifier: ...@@ -93,7 +93,7 @@ class ObjectBasedClassifier:
assert('folds' in self.training_base.keys()) assert('folds' in self.training_base.keys())
models = [] models = []
results = [] results = []
yt_yp = [] truelabs = np.array([])
for tr_i, ts_i in tqdm(self.training_base['folds'], desc='Training'): for tr_i, ts_i in tqdm(self.training_base['folds'], desc='Training'):
models.append(RandomForestClassifier(n_estimators=n_estimators)) models.append(RandomForestClassifier(n_estimators=n_estimators))
models[-1].fit(self.training_base['X'][tr_i], self.training_base[class_field][tr_i]) models[-1].fit(self.training_base['X'][tr_i], self.training_base[class_field][tr_i])
...@@ -102,6 +102,7 @@ class ObjectBasedClassifier: ...@@ -102,6 +102,7 @@ class ObjectBasedClassifier:
c = np.delete(c, np.isin(l, self.training_base['dummy_ids'])) c = np.delete(c, np.isin(l, self.training_base['dummy_ids']))
l = np.delete(l, np.isin(l, self.training_base['dummy_ids'])) l = np.delete(l, np.isin(l, self.training_base['dummy_ids']))
y_true, y_pred = self.obia_base.true_pred_bypixel(l, c, class_field) y_true, y_pred = self.obia_base.true_pred_bypixel(l, c, class_field)
truelabs = np.unique(np.concatenate((truelabs,y_true,y_pred)))
results.append( results.append(
{ {
'conf_matrix': confusion_matrix(y_true, y_pred, labels=np.unique(self.training_base[class_field])), 'conf_matrix': confusion_matrix(y_true, y_pred, labels=np.unique(self.training_base[class_field])),
...@@ -127,7 +128,8 @@ class ObjectBasedClassifier: ...@@ -127,7 +128,8 @@ class ObjectBasedClassifier:
'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0), 'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0),
'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0), 'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0),
'importance_mean': {k:v for k, v in zip(self.obia_base.get_vars(), np.mean(all_imp, axis=0))}, 'importance_mean': {k:v for k, v in zip(self.obia_base.get_vars(), np.mean(all_imp, axis=0))},
'importance_std': {k:v for k, v in zip(self.obia_base.get_vars(), np.std(all_imp, axis=0))} 'importance_std': {k:v for k, v in zip(self.obia_base.get_vars(), np.std(all_imp, axis=0))},
'actual_labels': list(truelabs)
} }
return models, summary, results return models, summary, results
......
import os import os
def parse_colormap_file(fn): def parse_colormap_file(fn, filter=None):
labels = [] labels = []
colors = [] colors = []
class_names = [] class_names = []
...@@ -8,6 +8,8 @@ def parse_colormap_file(fn): ...@@ -8,6 +8,8 @@ def parse_colormap_file(fn):
with open(fn, 'r') as f: with open(fn, 'r') as f:
for l in f.read().splitlines(): for l in f.read().splitlines():
sl = l.split(' ') sl = l.split(' ')
if filter is not None and int(sl[0]) not in filter:
continue
labels.append(int(sl[0])) labels.append(int(sl[0]))
colors.append((int(sl[1]),int(sl[2]),int(sl[3]),int(sl[4]))) colors.append((int(sl[1]),int(sl[2]),int(sl[3]),int(sl[4])))
class_names.append(' '.join(sl[5:])) class_names.append(' '.join(sl[5:]))
......
...@@ -30,12 +30,12 @@ def filter_and_order_importance(summary, importance_perc, max_num_var=35): ...@@ -30,12 +30,12 @@ def filter_and_order_importance(summary, importance_perc, max_num_var=35):
def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name=None, def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name=None,
importance_perc=0.75, max_variables=35): importance_perc=0.75, max_variables=35):
labels, class_names, colors = parse_colormap_file(palette_fn) labels, class_names, colors = parse_colormap_file(palette_fn, filter=summary['actual_labels'])
colors_norm = [(x[0]/255,x[1]/255,x[2]/255,x[3]/255) for x in colors] colors_norm = [(x[0]/255,x[1]/255,x[2]/255,x[3]/255) for x in colors]
with plt.ioff(): with plt.ioff():
#font = {'weight': 'normal', font = {'weight': 'normal',
# 'size': 8} 'size': 8}
#plt.rc('font', **font) plt.rc('font', **font)
if not os.path.exists(out_dir): if not os.path.exists(out_dir):
os.makedirs(out_dir) os.makedirs(out_dir)
if not isinstance(results, list): if not isinstance(results, list):
...@@ -73,7 +73,7 @@ def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name ...@@ -73,7 +73,7 @@ def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name
plt.tight_layout() plt.tight_layout()
plt.savefig(of['summary'], dpi=300) plt.savefig(of['summary'], dpi=300)
imp_m, imp_s, imp_n = filter_and_order_importance(summary, importance_perc) imp_m, imp_s, imp_n = filter_and_order_importance(summary, importance_perc, max_num_var=max_variables)
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.barh(range(len(imp_n)), imp_m, xerr=imp_s, align='center') ax.barh(range(len(imp_n)), imp_m, xerr=imp_s, align='center')
...@@ -194,7 +194,7 @@ def generate_pdf(of, out_pdf, name='output'): ...@@ -194,7 +194,7 @@ def generate_pdf(of, out_pdf, name='output'):
def generate_text_report(results, summary, palette_fn, output_fn, name='output'): def generate_text_report(results, summary, palette_fn, output_fn, name='output'):
if os.path.dirname(output_fn) != '': if os.path.dirname(output_fn) != '':
os.makedirs(os.path.dirname(output_fn), exist_ok=True) os.makedirs(os.path.dirname(output_fn), exist_ok=True)
labels, class_names, _ = parse_colormap_file(palette_fn) labels, class_names, _ = parse_colormap_file(palette_fn, filter=summary['actual_labels'])
lines = ['MORINGA Final Report for chain {}, {}'.format(name, datetime.now().strftime('%Y-%m-%d %Hh%M')), ''] lines = ['MORINGA Final Report for chain {}, {}'.format(name, datetime.now().strftime('%Y-%m-%d %Hh%M')), '']
table_lines = [] table_lines = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment