From 1b2823d1dcf64aa988c822a74c4751bef4de2972 Mon Sep 17 00:00:00 2001 From: "raffaele.gaetano" <raffaele.gaetano@cirad.fr> Date: Tue, 21 May 2019 20:39:42 +0200 Subject: [PATCH] ENH: k-fold report to latex table. --- validationFramework.py | 105 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/validationFramework.py b/validationFramework.py index 3d98e02..c80c7c5 100644 --- a/validationFramework.py +++ b/validationFramework.py @@ -134,7 +134,6 @@ def formatValidationTxt(classes,cm,acc,kappa,prf,txt_out): tid.close() - def genKFolds(shp,fld,k,out_fld=None): # Read all features and store classes in a separate array ds = ogr.Open(shp) @@ -251,6 +250,110 @@ def kFoldReport(fscores,accs,kappas,txt_out): tid.close() +def readKFoldReport(fn,cln,tag = None): + fid = open(fn, 'rb') + cid = open(cln, 'rb') + # read report + classes = [] + fsc_mean = [] + fsc_std = [] + oa_mean = None + oa_std = None + kc_mean = None + kc_std = None + + clnames = {} + notfound = tag is not None + cidl = cid.read().splitlines() + for cl in cidl: + if tag is not None and notfound: + if cl != tag: + continue + if cl == tag and notfound: + notfound = False + continue + + scl = cl.split(',') + if len(scl) == 2: + clnames[int(scl[0])] = scl[1] + else: + break + + for l in fid: + line = l.split() + if len(line) == 0: + continue + if line[0] == 'Class': + classes.append(int(line[1])) + fsc_mean.append(float(line[3])) + fsc_std.append(float(line[5])) + elif line[0] == 'Overall': + oa_mean = float(line[2][:-1]) + oa_std = float(line[4][:-1]) + elif line[0] == 'Kappa': + kc_mean = float(line[2]) + kc_std = float(line[4]) + + out = {} + out['PerClass'] = {} + out['OverallAcc'] = [oa_mean,oa_std] + out['Kappa'] = [kc_mean,kc_std] + out['ClassDict'] = clnames + + for i in range(len(classes)): + out['PerClass'][clnames[classes[i]]] = [fsc_mean[i],fsc_std[i]] + + fid.close() + cid.close() + + return out + +def kFoldReportToLatexTable(fn,cln,tag=None,ofn=None,mode='vertical'): + dct = readKFoldReport(fn,cln,tag) + if ofn == None: + ofn = fn.replace('.txt','.tex') + oid = open(ofn,'wb') + oid.write('\\documentclass{standalone}\n') + oid.write('\\usepackage[dvipsnames]{xcolor}\n') + oid.write('\\renewcommand\\familydefault{\\sfdefault}\n') + oid.write('\\begin{document}\n') + + pcf = [dct['ClassDict'][i] for i in sorted(dct['ClassDict'])] + + #def tabular + if mode == 'vertical': + oid.write('\\begin{tabular}{|c|c|}\n') + oid.write('\\hline\n') + oid.write('\\textbf{Class} & \\textbf{F-Score} \\\\\\hline\n') + for c in pcf: + clr = 'black' + if dct['PerClass'][c][0] < 0.3: + clr = 'red' + elif dct['PerClass'][c][0] >= 0.3 and dct['PerClass'][c][0] < 0.5: + clr = 'orange' + elif dct['PerClass'][c][0] > 0.75: + clr = 'ForestGreen' + oid.write('\\textit{%s} & \\color{%s}%1.4f$\\pm$%1.4f \\\\\\hline\n' % (c,clr,dct['PerClass'][c][0],dct['PerClass'][c][1])) + oid.write('\\hline\n') + clr = 'black' + if dct['OverallAcc'][0] < 0.3: + clr = 'red' + elif dct['OverallAcc'][0] >= 0.3 and dct['OverallAcc'][0] < 0.5: + clr = 'orange' + elif dct['OverallAcc'][0] > 0.75: + clr = 'ForestGreen' + oid.write('\\textbf{Overall Acc.} & {\\color{%s}\\textbf{%2.2f}\\%% $\\pm$ \\textbf{%2.2f}\\%%} \\\\\\hline\n' % (clr,dct['OverallAcc'][0],dct['OverallAcc'][1])) + oid.write('\\textbf{Kappa} & \\textbf{%0.4f} $\\pm$ \\textbf{%0.4f} \\\\\\hline\n' % (dct['Kappa'][0],dct['Kappa'][1])) + oid.write('\\end{tabular}') + + oid.write('\\end{document}\n') + oid.close() + cdir = os.getcwd() + os.chdir(os.path.dirname(ofn)) + subprocess.call(['pdflatex',ofn]) + os.chdir(cdir) + return + def getTrainingDataFromShapefile(shp,fields,code): ds = ogr.Open(shp) ly = ds.GetLayer(0) -- GitLab