diff --git a/Learning/ObjectBased.py b/Learning/ObjectBased.py index 2f87733522a0e842a7737d734fc94e435ca3314d..b6b65dd1c2510031bc409f838e16ca81631a074b 100644 --- a/Learning/ObjectBased.py +++ b/Learning/ObjectBased.py @@ -93,7 +93,7 @@ class ObjectBasedClassifier: assert('folds' in self.training_base.keys()) models = [] results = [] - yt_yp = [] + truelabs = np.array([]) for tr_i, ts_i in tqdm(self.training_base['folds'], desc='Training'): models.append(RandomForestClassifier(n_estimators=n_estimators)) models[-1].fit(self.training_base['X'][tr_i], self.training_base[class_field][tr_i]) @@ -102,6 +102,7 @@ class ObjectBasedClassifier: c = np.delete(c, np.isin(l, self.training_base['dummy_ids'])) l = np.delete(l, np.isin(l, self.training_base['dummy_ids'])) y_true, y_pred = self.obia_base.true_pred_bypixel(l, c, class_field) + truelabs = np.unique(np.concatenate((truelabs,y_true,y_pred))) results.append( { 'conf_matrix': confusion_matrix(y_true, y_pred, labels=np.unique(self.training_base[class_field])), @@ -127,7 +128,8 @@ class ObjectBasedClassifier: 'f1_mean': np.mean([x['p_r_f1'][2] for x in results], axis=0), 'f1_std': np.std([x['p_r_f1'][2] for x in results], axis=0), 'importance_mean': {k:v for k, v in zip(self.obia_base.get_vars(), np.mean(all_imp, axis=0))}, - 'importance_std': {k:v for k, v in zip(self.obia_base.get_vars(), np.std(all_imp, axis=0))} + 'importance_std': {k:v for k, v in zip(self.obia_base.get_vars(), np.std(all_imp, axis=0))}, + 'actual_labels': list(truelabs) } return models, summary, results diff --git a/Postprocessing/MapFormatting.py b/Postprocessing/MapFormatting.py index 71cdd43bc31a8c8cc755ba698c7074fd43f08ee8..6f9b26bb98ee0f98000829807b2f4e7b90d94300 100644 --- a/Postprocessing/MapFormatting.py +++ b/Postprocessing/MapFormatting.py @@ -1,6 +1,6 @@ import os -def parse_colormap_file(fn): +def parse_colormap_file(fn, filter=None): labels = [] colors = [] class_names = [] @@ -8,6 +8,8 @@ def parse_colormap_file(fn): with open(fn, 'r') as f: for l in f.read().splitlines(): sl = l.split(' ') + if filter is not None and int(sl[0]) not in filter: + continue labels.append(int(sl[0])) colors.append((int(sl[1]),int(sl[2]),int(sl[3]),int(sl[4]))) class_names.append(' '.join(sl[5:])) diff --git a/Postprocessing/Report.py b/Postprocessing/Report.py index 96b2219d8a78b522f5484b0b89ae226953fc87ce..589ae2ec6ba56036226be112aed449ce66fd2fc6 100644 --- a/Postprocessing/Report.py +++ b/Postprocessing/Report.py @@ -30,12 +30,12 @@ def filter_and_order_importance(summary, importance_perc, max_num_var=35): def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name=None, importance_perc=0.75, max_variables=35): - labels, class_names, colors = parse_colormap_file(palette_fn) + labels, class_names, colors = parse_colormap_file(palette_fn, filter=summary['actual_labels']) colors_norm = [(x[0]/255,x[1]/255,x[2]/255,x[3]/255) for x in colors] with plt.ioff(): - #font = {'weight': 'normal', - # 'size': 8} - #plt.rc('font', **font) + font = {'weight': 'normal', + 'size': 8} + plt.rc('font', **font) if not os.path.exists(out_dir): os.makedirs(out_dir) if not isinstance(results, list): @@ -73,7 +73,7 @@ def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name plt.tight_layout() plt.savefig(of['summary'], dpi=300) - imp_m, imp_s, imp_n = filter_and_order_importance(summary, importance_perc) + imp_m, imp_s, imp_n = filter_and_order_importance(summary, importance_perc, max_num_var=max_variables) fig, ax = plt.subplots() ax.barh(range(len(imp_n)), imp_m, xerr=imp_s, align='center') @@ -194,7 +194,7 @@ def generate_pdf(of, out_pdf, name='output'): def generate_text_report(results, summary, palette_fn, output_fn, name='output'): if os.path.dirname(output_fn) != '': os.makedirs(os.path.dirname(output_fn), exist_ok=True) - labels, class_names, _ = parse_colormap_file(palette_fn) + labels, class_names, _ = parse_colormap_file(palette_fn, filter=summary['actual_labels']) lines = ['MORINGA Final Report for chain {}, {}'.format(name, datetime.now().strftime('%Y-%m-%d %Hh%M')), ''] table_lines = []