From 1b2823d1dcf64aa988c822a74c4751bef4de2972 Mon Sep 17 00:00:00 2001
From: "raffaele.gaetano" <raffaele.gaetano@cirad.fr>
Date: Tue, 21 May 2019 20:39:42 +0200
Subject: [PATCH] ENH: k-fold report to latex table.

---
 validationFramework.py | 105 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 104 insertions(+), 1 deletion(-)

diff --git a/validationFramework.py b/validationFramework.py
index 3d98e02..c80c7c5 100644
--- a/validationFramework.py
+++ b/validationFramework.py
@@ -134,7 +134,6 @@ def formatValidationTxt(classes,cm,acc,kappa,prf,txt_out):
 
     tid.close()
 
-
 def genKFolds(shp,fld,k,out_fld=None):
     # Read all features and store classes in a separate array
     ds = ogr.Open(shp)
@@ -251,6 +250,110 @@ def kFoldReport(fscores,accs,kappas,txt_out):
 
     tid.close()
 
+def readKFoldReport(fn,cln,tag = None):
+    fid = open(fn, 'rb')
+    cid = open(cln, 'rb')
+    # read report
+    classes = []
+    fsc_mean = []
+    fsc_std = []
+    oa_mean = None
+    oa_std = None
+    kc_mean = None
+    kc_std = None
+
+    clnames = {}
+    notfound = tag is not None
+    cidl = cid.read().splitlines()
+    for cl in cidl:
+        if tag is not None and notfound:
+            if cl != tag:
+                continue
+            if cl == tag and notfound:
+                notfound = False
+                continue
+
+        scl = cl.split(',')
+        if len(scl) == 2:
+            clnames[int(scl[0])] = scl[1]
+        else:
+            break
+
+    for l in fid:
+        line = l.split()
+        if len(line) == 0:
+            continue
+        if line[0] == 'Class':
+            classes.append(int(line[1]))
+            fsc_mean.append(float(line[3]))
+            fsc_std.append(float(line[5]))
+        elif line[0] == 'Overall':
+            oa_mean = float(line[2][:-1])
+            oa_std = float(line[4][:-1])
+        elif line[0] == 'Kappa':
+            kc_mean = float(line[2])
+            kc_std = float(line[4])
+
+    out = {}
+    out['PerClass'] = {}
+    out['OverallAcc'] = [oa_mean,oa_std]
+    out['Kappa'] = [kc_mean,kc_std]
+    out['ClassDict'] = clnames
+
+    for i in range(len(classes)):
+        out['PerClass'][clnames[classes[i]]] = [fsc_mean[i],fsc_std[i]]
+
+    fid.close()
+    cid.close()
+
+    return out
+
+def kFoldReportToLatexTable(fn,cln,tag=None,ofn=None,mode='vertical'):
+    dct = readKFoldReport(fn,cln,tag)
+    if ofn == None:
+        ofn = fn.replace('.txt','.tex')
+    oid = open(ofn,'wb')
+    oid.write('\\documentclass{standalone}\n')
+    oid.write('\\usepackage[dvipsnames]{xcolor}\n')
+    oid.write('\\renewcommand\\familydefault{\\sfdefault}\n')
+    oid.write('\\begin{document}\n')
+
+    pcf = [dct['ClassDict'][i] for i in sorted(dct['ClassDict'])]
+
+    #def tabular
+    if mode == 'vertical':
+        oid.write('\\begin{tabular}{|c|c|}\n')
+        oid.write('\\hline\n')
+        oid.write('\\textbf{Class} & \\textbf{F-Score} \\\\\\hline\n')
+        for c in pcf:
+            clr = 'black'
+            if dct['PerClass'][c][0] < 0.3:
+                clr = 'red'
+            elif dct['PerClass'][c][0] >= 0.3 and dct['PerClass'][c][0] < 0.5:
+                clr = 'orange'
+            elif dct['PerClass'][c][0] > 0.75:
+                clr = 'ForestGreen'
+            oid.write('\\textit{%s} & \\color{%s}%1.4f$\\pm$%1.4f \\\\\\hline\n' % (c,clr,dct['PerClass'][c][0],dct['PerClass'][c][1]))
+        oid.write('\\hline\n')
+        clr = 'black'
+        if dct['OverallAcc'][0] < 0.3:
+            clr = 'red'
+        elif dct['OverallAcc'][0] >= 0.3 and dct['OverallAcc'][0] < 0.5:
+            clr = 'orange'
+        elif dct['OverallAcc'][0] > 0.75:
+            clr = 'ForestGreen'
+        oid.write('\\textbf{Overall Acc.} & {\\color{%s}\\textbf{%2.2f}\\%% $\\pm$ \\textbf{%2.2f}\\%%} \\\\\\hline\n' % (clr,dct['OverallAcc'][0],dct['OverallAcc'][1]))
+        oid.write('\\textbf{Kappa} & \\textbf{%0.4f} $\\pm$ \\textbf{%0.4f} \\\\\\hline\n' % (dct['Kappa'][0],dct['Kappa'][1]))
+        oid.write('\\end{tabular}')
+
+    oid.write('\\end{document}\n')
+    oid.close()
+    cdir = os.getcwd()
+    os.chdir(os.path.dirname(ofn))
+    subprocess.call(['pdflatex',ofn])
+    os.chdir(cdir)
+    return
+
 def getTrainingDataFromShapefile(shp,fields,code):
     ds = ogr.Open(shp)
     ly = ds.GetLayer(0)
-- 
GitLab