# classificationWorkflow.py
import ogr
import sys
import subprocess
import platform
import numpy as np

def getFeaturesFields(shp, flds_pref):
    """Return the attribute field names of a vector file that match given prefixes.

    Only the first layer of ``shp`` is inspected.

    :param shp: path of a vector file readable by OGR.
    :param flds_pref: a single prefix string, or an iterable of prefix strings.
    :return: list of matching field names, in layer-definition order.
    """
    # A bare string must be wrapped: tuple('ab') would yield ('a', 'b') and
    # match every field starting with 'a' or 'b'.
    if isinstance(flds_pref, str):
        prefixes = (flds_pref,)
    else:
        prefixes = tuple(flds_pref)

    ds = ogr.Open(shp, 0)  # 0 = read-only
    if ds is None:
        sys.exit('ERROR: cannot open ' + shp + '.')
    try:
        ldfn = ds.GetLayer(0).GetLayerDefn()
        names = [ldfn.GetFieldDefn(n).GetName() for n in range(ldfn.GetFieldCount())]
    finally:
        # Dereference the datasource to close it.
        ds = None

    return [fn for fn in names if fn.startswith(prefixes)]

def roughFix(shp, flds):
    """Replace missing attribute values by per-field means, in place.

    For every feature of the first layer of ``shp``, the fields listed in
    ``flds`` are read as doubles. Values equal to -9999.0 (and any NaN) are
    treated as missing. Each one is replaced by the mean of the non-missing
    values of the same field, and the result is written back to the file.
    A field whose values are all missing keeps NaN, because the mean of an
    empty set is NaN.

    NOTE: this was described as an "R-like rough fix"; R's
    ``randomForest::na.roughfix`` uses the column median for numeric data,
    whereas this function uses the mean.

    :param shp: path of a vector file writable by OGR (opened in update mode).
    :param flds: names of the numeric fields to fix.
    """
    ds = ogr.Open(shp, 1)  # 1 = update mode
    if ds is None:
        sys.exit('ERROR: cannot open ' + shp + ' in update mode.')
    ly = None
    try:
        ly = ds.GetLayer(0)
        flds = list(flds)

        # Read all values into an (n_features, n_fields) matrix.
        rows = [[f.GetFieldAsDouble(fld) for fld in flds] for f in ly]
        if not rows or not flds:
            # Nothing to impute (and avoids nanmean warnings on empty input).
            return
        arr = np.array(rows, dtype=float)

        # -9999.0 is the missing-value marker.
        arr[arr == -9999.0] = np.nan
        col_means = np.nanmean(arr, axis=0)
        # col_means broadcasts over rows: each NaN takes its own column's mean.
        arr = np.where(np.isnan(arr), col_means, arr)

        ly.ResetReading()
        for f, values in zip(ly, arr):
            for fld, value in zip(flds, values):
                # Plain float: avoid handing numpy scalars to the GDAL bindings.
                f.SetField(fld, float(value))
            # Write each feature once, after all of its fields are updated.
            ly.SetFeature(f)
    finally:
        # Dereferencing layer and datasource flushes pending writes and closes the file.
        ly = None
        ds = None

    return


def training(shp, code, model_fld, params, feat, feat_mode='list'):
    """Compute feature statistics and train an OTB vector classifier.

    Runs OTB ``ComputeVectorFeaturesStatistics`` and then
    ``TrainVectorClassifier`` on ``shp``. On Linux the ``otbcli_*``
    launchers are invoked; on Windows the ``otbApplication`` Python API is
    used. Any other platform aborts via ``sys.exit``.

    :param shp: path of the training vector file.
    :param code: name of the field holding the class labels (``cfield``).
    :param model_fld: output folder. It receives ``GT_stats.xml``,
        ``<classifier>_<code>.model`` and ``<classifier>_<code>.confmat.txt``.
    :param params: flat list of extra TrainVectorClassifier arguments, e.g.
        ``['-classifier', 'rf', '-classifier.rf.max', '10']``. The value
        following ``-classifier`` names the output files; ``'libsvm'`` is
        used in the names when it is absent.
    :param feat: feature field names (``feat_mode='list'``) or field-name
        prefixes (``feat_mode='prefix'``).
    :param feat_mode: ``'list'`` or ``'prefix'``.
    :return: the feature field names used for training.
    """
    system = platform.system()
    if system not in ('Linux', 'Windows'):
        sys.exit("Platform not supported!")

    # Accept tuples too: the CLI command below is built by list concatenation.
    params = list(params)
    if '-classifier' in params:
        classifier = params[params.index('-classifier') + 1]
    else:
        classifier = 'libsvm'
    model_file = model_fld + '/' + classifier + '_' + code + '.model'
    confmat_file = model_fld + '/' + classifier + '_' + code + '.confmat.txt'
    stat_file = model_fld + '/GT_stats.xml'

    if feat_mode == 'prefix':
        flds = getFeaturesFields(shp, feat)
    elif feat_mode == 'list':
        flds = feat
    else:
        sys.exit('ERROR: mode ' + feat_mode + ' not valid.')

    if system == 'Linux':
        commands = [
            ['otbcli_ComputeVectorFeaturesStatistics', '-io.vd', shp,
             '-io.stats', stat_file, '-feat'] + list(flds),
            ['otbcli_TrainVectorClassifier', '-io.vd', shp,
             '-io.stats', stat_file, '-io.confmatout', confmat_file,
             '-cfield', code, '-io.out', model_file, '-feat'] + list(flds) + params,
        ]
        for cmd in commands:
            # Argument list without a shell: values are passed verbatim.
            ret = subprocess.call(cmd)
            if ret != 0:
                # Training needs the statistics file, and callers need the
                # model, so a failed step must not be silently skipped.
                sys.exit('ERROR: ' + cmd[0] + ' failed with exit code ' + str(ret) + '.')
    else:
        import otbApplication

        app = otbApplication.Registry.CreateApplication('ComputeVectorFeaturesStatistics')
        app.SetParameterStringList('io.vd', [shp])
        app.SetParameterString('io.stats', stat_file)
        app.UpdateParameters()
        app.SetParameterStringList('feat', list(flds))
        app.ExecuteAndWriteOutput()

        app = otbApplication.Registry.CreateApplication('TrainVectorClassifier')
        app.SetParameterStringList('io.vd', [shp])
        app.SetParameterString('io.stats', stat_file)
        app.SetParameterString('io.confmatout', confmat_file)
        app.SetParameterString('io.out', model_file)
        app.UpdateParameters()
        app.SetParameterStringList('cfield', [code])
        app.SetParameterStringList('feat', list(flds))
        # Extra parameters are consumed as ('-key', value) pairs; the leading
        # '-' is stripped to get the application key.
        # WARNING: only single-valued parameters are supported; a list-valued
        # one (e.g. -classifier.ann.sizes) will shift the pairing.
        for prk, prv in zip(params[0::2], params[1::2]):
            app.SetParameterString(prk[1:], prv)
        app.UpdateParameters()
        app.ExecuteAndWriteOutput()

    return flds

def classify(shp_list, code, stat_file, model_file, feat, feat_mode='list'):
    """Classify every vector file in ``shp_list`` with a trained OTB model.

    Runs OTB ``VectorClassifier`` on each file. On Linux the
    ``otbcli_VectorClassifier`` launcher is invoked; on Windows the
    ``otbApplication`` Python API is used. Any other platform aborts via
    ``sys.exit``. No output file is passed to VectorClassifier; presumably
    each input file is updated in place (confirm against the OTB docs).

    :param shp_list: sequence of vector file paths to classify. An empty
        sequence is a no-op.
    :param code: field name passed as ``cfield`` (the field receiving the
        predicted class).
    :param stat_file: feature statistics XML produced at training time.
    :param model_file: trained classifier model file.
    :param feat: feature field names (``feat_mode='list'``) or field-name
        prefixes (``feat_mode='prefix'``, resolved on the first file).
    :param feat_mode: ``'list'`` or ``'prefix'``.
    """
    system = platform.system()
    if system not in ('Linux', 'Windows'):
        sys.exit("Platform not supported!")
    if feat_mode not in ('prefix', 'list'):
        sys.exit('ERROR: mode ' + feat_mode + ' not valid.')
    if not shp_list:
        # Nothing to classify; also avoids indexing shp_list[0] below.
        return

    if feat_mode == 'prefix':
        flds = list(getFeaturesFields(shp_list[0], feat))
    else:
        flds = list(feat)

    if system == 'Linux':
        for shp in shp_list:
            cmd = ['otbcli_VectorClassifier', '-in', shp, '-instat', stat_file,
                   '-model', model_file, '-cfield', code, '-feat'] + flds
            # Argument list without a shell: values are passed verbatim.
            ret = subprocess.call(cmd)
            if ret != 0:
                sys.exit('ERROR: otbcli_VectorClassifier failed on ' + shp
                         + ' with exit code ' + str(ret) + '.')
    else:
        import otbApplication
        for shp in shp_list:
            app = otbApplication.Registry.CreateApplication('VectorClassifier')
            app.SetParameterString('in', shp)
            app.SetParameterString('instat', stat_file)
            app.SetParameterString('model', model_file)
            app.SetParameterString('cfield', code)
            app.UpdateParameters()
            app.SetParameterStringList('feat', flds)
            app.UpdateParameters()
            app.ExecuteAndWriteOutput()

    return