Report.py 8.01 KiB
import matplotlib
matplotlib.use('Agg')
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, classification_report, accuracy_score, cohen_kappa_score
from fpdf import FPDF
import numpy as np
import rasterio
from rasterio.enums import Resampling
from pyproj import Transformer as T
from datetime import datetime
import os
from Postprocessing.MapFormatting import parse_colormap_file

def _rank_importances(summary, importance_perc, max_variables):
    """Select the most important features for the importance bar chart.

    Features are ordered by decreasing mean importance using a single
    permutation, so names, means and standard deviations stay aligned even
    when several features share the same mean importance.

    Args:
        summary: Mapping with 'importance_mean' and 'importance_std' entries,
            each a dict keyed by feature name. Both dicts are assumed to list
            the features in the same order -- TODO confirm against the caller.
        importance_perc: Fraction of the total (cumulative) importance used
            as the selection threshold.
        max_variables: Upper bound on the number of features returned.

    Returns:
        Tuple ``(names, means, stds)`` of equally sized lists.
    """
    names = list(summary['importance_mean'].keys())
    means = list(summary['importance_mean'].values())
    stds = list(summary['importance_std'].values())
    if not means:
        return [], [], []

    order = sorted(range(len(means)), key=lambda j: means[j], reverse=True)
    names = [names[j] for j in order]
    means = [means[j] for j in order]
    stds = [stds[j] for j in order]

    cumulative = np.cumsum(means)
    below = np.flatnonzero(cumulative < importance_perc * cumulative[-1])
    # NOTE(review): as in the previous implementation, the count is the index
    # of the last feature still below the threshold, i.e. one fewer than the
    # number of features below it -- confirm this is the intended cut-off.
    n_keep = int(below[-1]) if below.size else 0
    # Always show at least the top feature (the old code crashed or produced
    # an empty plot here), and honour the caller-supplied cap.
    n_keep = min(max(n_keep, 1), max(max_variables, 0))
    return names[:n_keep], means[:n_keep], stds[:n_keep]


def generate_report_figures(map, palette_fn, results, summary, out_dir, map_name=None,
                            importance_perc=0.75, max_variables=35):
    """Render every figure of the classification report into ``out_dir``.

    Produces one normalised confusion matrix per fold, a per-class F1-score
    bar chart, a feature-importance chart and the map quicklook.

    Args:
        map: Path of the classification raster, forwarded to
            ``create_map_quicklook_and_legend``.
        palette_fn: Colormap file parsed with ``parse_colormap_file`` into
            labels, class names and RGBA colours (0-255).
        results: One result dict or a list of them (one per fold). Each must
            provide ``'true_vs_pred'`` as a (true labels, predicted labels)
            pair; the quicklook step also reads ``'accuracy'`` and ``'kappa'``.
        summary: Dict with 'importance_mean' and 'importance_std' mappings
            keyed by feature name.
        out_dir: Output directory, created if missing.
        map_name: Prefix of the quicklook file; defaults to 'output'.
        importance_perc: Fraction of cumulative importance used to select the
            features shown in the importance chart.
        max_variables: Maximum number of features shown in that chart.

    Returns:
        Dict with keys 'conf_matrices' (list of PNG paths), 'cl_rep' (list of
        classification-report dicts), 'summary', 'importances' and
        'quicklook' (PNG paths).
    """
    labels, class_names, colors = parse_colormap_file(palette_fn)
    colors_norm = [(c[0] / 255, c[1] / 255, c[2] / 255, c[3] / 255) for c in colors]
    os.makedirs(out_dir, exist_ok=True)
    if not isinstance(results, list):
        results = [results]

    of = {'conf_matrices': [], 'cl_rep': []}
    with plt.ioff():
        # One confusion matrix and one classification report per fold.
        for i, r in enumerate(results):
            y_true, y_pred = r['true_vs_pred'][0], r['true_vs_pred'][1]
            cm = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true',
                                                         include_values=True, values_format='.2f')
            of['conf_matrices'].append('{}/conf_matrix_{}.png'.format(out_dir, str(i).zfill(2)))
            cm.ax_.set_xticklabels(class_names, rotation=45, ha='right')
            cm.ax_.set_yticklabels(class_names)
            cm.ax_.set_xlabel("Predicted label", labelpad=10)
            cm.ax_.set_ylabel("True label", labelpad=15)
            cm.figure_.tight_layout()
            cm.figure_.savefig(of['conf_matrices'][-1], dpi=300)
            # Release the figure: pyplot keeps every open figure alive otherwise.
            plt.close(cm.figure_)
            of['cl_rep'].append(classification_report(y_true, y_pred, output_dict=True,
                                                      target_names=class_names))

        # Per-class F1-score: mean and standard deviation across folds.
        fsc = [np.array([rep[c]['f1-score'] for rep in of['cl_rep']]) for c in class_names]
        fsc_m = [np.mean(x) for x in fsc]
        fsc_s = [np.std(x) for x in fsc]
        fig, ax = plt.subplots()
        ax.bar(range(len(class_names)), fsc_m, yerr=fsc_s, align="center", width=0.3,
               ecolor='black', capsize=10, color=colors_norm)
        ax.set_xticks(range(len(class_names)))
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_title('Per-class F1-scores', fontsize=12, fontweight='bold', pad=10)
        ax.yaxis.grid(True)
        of['summary'] = '{}/f1scores.png'.format(out_dir)
        fig.set_figwidth(4)
        fig.tight_layout()
        fig.savefig(of['summary'], dpi=300)
        plt.close(fig)

        # Feature importances, most important first.
        imp_n, imp_m, imp_s = _rank_importances(summary, importance_perc, max_variables)
        fig, ax = plt.subplots()
        ax.barh(range(len(imp_n)), imp_m, xerr=imp_s, align='center')
        ax.set_yticks(range(len(imp_n)))
        ax.set_yticklabels(imp_n, fontsize=6)
        ax.invert_yaxis()
        ax.set_title('Feature Importances')
        ax.set_xlabel('Mean Decrease in Impurity')
        ax.xaxis.grid(True)
        of['importances'] = '{}/importances.png'.format(out_dir)
        fig.set_figwidth(5)
        fig.tight_layout()
        fig.savefig(of['importances'], dpi=300)
        plt.close(fig)

        if map_name is None:
            map_name = 'output'
        of['quicklook'] = create_map_quicklook_and_legend(map, labels, colors, class_names,
                                                          results, out_dir, map_name)

    return of

def create_map_quicklook_and_legend(map, labels, colors, class_names, results, out_dir, name='', qkl_height=1024):
    """Save a coloured, down-sampled preview of the classification map.

    The first band of ``map`` is read at ``qkl_height`` rows (nearest
    neighbour), coloured with the class palette and drawn with a legend,
    longitude/latitude tick labels and the mean +/- std of overall accuracy
    and Cohen's kappa over the folds.

    Args:
        map: Path of the classification raster (opened with rasterio).
        labels: Integer class labels, used as indices into the colour table.
        colors: RGBA colours (0-255), one per label.
        class_names: Legend entries, one per label.
        results: One result dict or a list of them; each must provide
            'accuracy' and 'kappa'.
        out_dir: Output directory, created if missing.
        name: Prefix of the output file name.
        qkl_height: Height of the quicklook in pixels.

    Returns:
        Path of the written PNG, ``'<out_dir>/<name>_QKL.png'``.
    """
    os.makedirs(out_dir, exist_ok=True)
    if not isinstance(results, list):
        results = [results]
    oa = np.array([r['accuracy'] for r in results])
    k = np.array([r['kappa'] for r in results])
    oam, oas = np.mean(oa), np.std(oa)
    km, ks = np.mean(k), np.std(k)

    with rasterio.open(map) as ds:
        tr = ds.transform
        # to_epsg() returns None for CRSs without an EPSG match; fall back to
        # the CRS object itself, which pyproj can consume through its WKT.
        epsg = ds.crs.to_epsg()
        src_crs = epsg if epsg is not None else ds.crs
        full_width, full_height = ds.width, ds.height
        qkl_width = max(1, int(full_width * (qkl_height / full_height)))
        smap = ds.read(1,
            out_shape=(1, qkl_height, qkl_width),
            resampling=Resampling.nearest
        )
    # Number of full-resolution pixels covered by one quicklook pixel.
    sx = full_width / qkl_width
    sy = full_height / qkl_height

    # Colour lookup table indexed by raster value. Sized to also cover raster
    # values that are not in the palette; those stay fully transparent.
    lut = np.zeros((max(int(np.max(labels)), int(smap.max())) + 1, 4))
    lut[labels] = np.array([np.array(list(c)) for c in colors])
    img = lut[smap].astype(int)

    out_fn = '{}/{}_QKL.png'.format(out_dir, name)
    # Scope the small font to this figure instead of changing global rcParams.
    with plt.rc_context({'font.weight': 'normal', 'font.size': 6}):
        fig, ax = plt.subplots()
        try:
            ax.imshow(img)
            # Zero-sized patches only serve as coloured legend handles.
            custom_leg = [Rectangle([0, 0], 0, 0, fill=True, edgecolor=(0, 0, 0),
                                    facecolor=([x[0]/255, x[1]/255, x[2]/255, x[3]/255])) for x in colors]
            ax.legend(custom_leg, class_names, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=6)

            conv = T.from_crs(src_crs, 4326, always_xy=True)
            # Drop the outermost ticks, as before.
            xt, yt = ax.get_xticks()[1:-1], ax.get_yticks()[1:-1]
            # Ticks are in quicklook pixel units (pixel centres at integers):
            # rescale to full-resolution pixels before applying the transform.
            # Longitudes are evaluated along the bottom edge and latitudes
            # along the left edge, where the respective axes are drawn.
            # Assumes a north-up raster (rotation terms of the transform are
            # ignored, as in the previous implementation).
            x_left = tr[2]
            y_bottom = tr[5] + tr[4] * full_height
            xpos = ['{:.3f}°'.format(conv.transform(tr[2] + tr[0] * (x + 0.5) * sx, y_bottom)[0])
                    for x in xt]
            ypos = ['{:.3f}°'.format(conv.transform(x_left, tr[5] + tr[4] * (y + 0.5) * sy)[1])
                    for y in yt]
            ax.set_xticks(xt)
            ax.set_yticks(yt)
            ax.set_xticklabels(xpos, rotation=45, ha='right')
            ax.set_yticklabels(ypos)
            ax.set_title("Final Classification Map Quicklook", fontsize=12, fontweight='bold', pad=15)
            ax.set_xlabel("Overall Map Accuracy : {:.2f} +/- {:.2f}\nCohen's Kappa : {:.2f} +/- {:.2f}".format(oam,oas,km,ks),
                          fontsize=10, fontweight='bold', labelpad=15)
            fig.tight_layout()
            fig.savefig(out_fn, dpi=300)
        finally:
            # Release the figure even if rendering fails.
            plt.close(fig)
    return out_fn

def generate_pdf(of, out_pdf, name='output'):
    """Assemble the report figures into an A4 portrait PDF.

    Page 1 holds the title, the map quicklook, the F1-score chart and the
    feature-importance chart. Every following page covers one fold: its
    confusion matrix plus a table built from its classification report.

    Args:
        of: Dict providing 'quicklook', 'summary' and 'importances' image
            paths, plus the parallel lists 'conf_matrices' (image paths) and
            'cl_rep' (classification-report dicts).
        out_pdf: Destination path of the PDF.
        name: Chain name printed in the title.
    """
    def _write_metrics_row(table, entry):
        # entry = [label, precision, recall, f1-score, support]
        line = table.row()
        line.cell(entry[0])
        for value in entry[1:4]:
            line.cell('{:.4f}'.format(value), align='R')
        line.cell('{}'.format(entry[-1]), align='R')

    header = ["Class", "Precision", "Recall", "F1-score", "Support"]

    pdf = FPDF('P', 'mm', 'A4')
    pdf.set_font("helvetica", 'B', 16)

    # Page 1 - Summary (OA/Kappa/Quickview)
    pdf.add_page()
    pdf.set_xy(0, 16)
    timestamp = datetime.now().strftime('%Y-%m-%d %Hh%M')
    pdf.cell(0, txt='Moringa Final Report for Chain {}, {}'.format(name, timestamp), align='C')
    pdf.image(of['quicklook'], 14, 24, h=140)
    pdf.image(of['summary'], 8, 170, w=86)
    pdf.image(of['importances'], 100, 170, w=100)

    # Pages 2-end, Per-fold assessment
    for fold, (matrix_png, report) in enumerate(zip(of['conf_matrices'], of['cl_rep']), start=1):
        pdf.add_page()
        pdf.set_xy(0, 16)
        pdf.set_font("helvetica", '', 14)
        pdf.cell(0, txt="Per-fold assessment - Fold #{}".format(str(fold).zfill(2)), align='C')
        pdf.image(matrix_png, 14, 24, h=120)

        # Every report entry except the third from the end, which is read
        # separately as a scalar through report['accuracy'] below.
        keys = list(report.keys())
        rows = [header]
        rows.extend(
            [key, report[key]['precision'], report[key]['recall'],
             report[key]['f1-score'], report[key]['support']]
            for key in keys[:-3] + keys[-2:]
        )

        pdf.set_xy(0, 150)
        pdf.set_font("helvetica", '', 6)
        with pdf.table(width=140, col_widths=(60, 20, 20, 20, 20)) as table:
            head = table.row()
            for title in rows[0]:
                head.cell(title, align='C')
            # All rows but the last two.
            for entry in rows[1:-2]:
                _write_metrics_row(table, entry)
            # Separator line, then the last two rows.
            table.row().cell('Summary')
            for entry in rows[-2:]:
                _write_metrics_row(table, entry)
            accuracy_line = table.row()
            accuracy_line.cell('Accuracy')
            accuracy_line.cell('{:.2f}%'.format(report['accuracy']*100), align='R')

    pdf.output(out_pdf)