# classificationWorkflow.py (2.61 KiB)
import ogr
import sys
import subprocess
import platform
import numpy as np

def getFeaturesFields(shp, flds_pref):
    """Return the attribute field names of ``shp`` that start with a prefix.

    Only the first layer of the vector file is inspected; names are returned
    in layer-definition order.

    Args:
        shp: Path of the vector file (opened through ``ogr.Open`` with
            update flag 0).
        flds_pref: A single prefix string, or an iterable of prefix strings.

    Returns:
        List of matching field names (empty when nothing matches).

    Raises:
        IOError: If OGR cannot open ``shp``.
    """
    # tuple() on a bare string would split it into single characters, so any
    # field starting with one of its letters would match; treat a string as
    # one prefix instead.
    if isinstance(flds_pref, str):
        prefixes = (flds_pref,)
    else:
        prefixes = tuple(flds_pref)

    ds = ogr.Open(shp, 0)
    if ds is None:
        # When GDAL exceptions are not enabled, a failed open returns None.
        raise IOError("Could not open vector file: %s" % shp)
    ldfn = ds.GetLayer(0).GetLayerDefn()

    flds = []
    for n in range(ldfn.GetFieldCount()):
        fn = ldfn.GetFieldDefn(n).name
        if fn.startswith(prefixes):
            flds.append(fn)

    # Dropping the last reference closes the dataset.
    ds = None

    return flds

def roughFix(shp, flds):
    """Replace NoData values of the given fields with the field mean, in place.

    Every value equal to -9999.0 in one of ``flds`` is replaced by the mean of
    the remaining values of that same field, computed over all features of
    the first layer. The vector file is modified on disk.

    Args:
        shp: Path of the vector file (opened through ``ogr.Open`` with
            update flag 1).
        flds: Sequence of field names to repair; values are read with
            ``GetFieldAsDouble``.

    Returns:
        None.

    Raises:
        IOError: If OGR cannot open ``shp`` for update.
    """
    ds = ogr.Open(shp, 1)
    if ds is None:
        # When GDAL exceptions are not enabled, a failed open returns None.
        raise IOError("Could not open vector file for update: %s" % shp)
    ly = ds.GetLayer(0)

    n_feat = ly.GetFeatureCount()
    if n_feat == 0 or not flds:
        # Nothing to repair; also avoids taking the mean of an empty slice.
        ds = None
        return

    # Load the feature table: one row per feature, one column per field.
    arr = np.empty([n_feat, len(flds)])
    for i, f in enumerate(ly):
        for j, fld in enumerate(flds):
            arr[i, j] = f.GetFieldAsDouble(fld)

    # R-like rough fix: mark NoData as NaN, then fill with the column mean.
    # A column made only of NoData has no mean and stays NaN.
    arr[arr == -9999.0] = np.nan
    col_means = np.nanmean(arr, axis=0)
    arr = np.where(np.isnan(arr), col_means, arr)

    ly.ResetReading()
    for i, f in enumerate(ly):
        for j, fld in enumerate(flds):
            f.SetField(fld, float(arr[i, j]))
        # Write the feature back once, after all of its fields are updated.
        ly.SetFeature(f)

    # Dropping the last reference flushes and closes the dataset.
    ds = None

    return


def training(shp, feat_prefix, code, model_fld, params):
    """Train an OTB vector classifier from the features stored in ``shp``.

    The function selects the feature fields by prefix, repairs their NoData
    values with ``roughFix`` (which rewrites ``shp``), then runs
    ``otbcli_ComputeVectorFeaturesStatistics`` followed by
    ``otbcli_TrainVectorClassifier``. The exit codes of both commands are
    not inspected.

    Args:
        shp: Path of the training vector file.
        feat_prefix: Prefix(es) selecting the feature fields, as accepted by
            ``getFeaturesFields``.
        code: Name of the class field (passed as ``-cfield``); also used in
            the output file names.
        model_fld: Folder receiving the model, the confusion matrix and
            ``GT_stats.xml``.
        params: List of extra command-line tokens appended to the training
            command; the token following ``-classifier``, if present, names
            the output files.

    Returns:
        The list of feature field names used for training.
    """
    # Platform dependent parameters
    system_name = platform.system()
    if system_name not in ('Linux', 'Windows'):
        sys.exit("Platform not supported!")
    use_shell = system_name == 'Windows'

    # The classifier name is only used here to label the output files.
    classifier = params[params.index('-classifier') + 1] if '-classifier' in params else 'libsvm'
    out_prefix = model_fld + '/' + classifier + '_' + code
    model_file = out_prefix + '.model'
    confmat_file = out_prefix + '.confmat.txt'
    stat_file = model_fld + '/GT_stats.xml'

    feature_names = getFeaturesFields(shp, feat_prefix)
    roughFix(shp, feature_names)

    # Step 1: per-feature statistics, stored in stat_file.
    stats_cmd = [
        'otbcli_ComputeVectorFeaturesStatistics',
        '-io.vd', shp,
        '-io.stats', stat_file,
        '-feat',
    ] + feature_names
    subprocess.call(stats_cmd, shell=use_shell)

    # Step 2: training, reading the statistics written by step 1.
    train_cmd = [
        'otbcli_TrainVectorClassifier',
        '-io.vd', shp,
        '-io.stats', stat_file,
        '-io.confmatout', confmat_file,
        '-cfield', code,
        '-io.out', model_file,
        '-feat',
    ] + feature_names + params
    subprocess.call(train_cmd, shell=use_shell)

    return feature_names

def classify(shp_list, feat_prefix, code, stat_file, model_file):
    """Run ``otbcli_VectorClassifier`` on every vector file of ``shp_list``.

    The feature field names are read once, from the first file, and reused
    for all files. The exit codes of the OTB commands are not inspected.

    Args:
        shp_list: Sequence of vector file paths; an empty sequence is a no-op.
        feat_prefix: Prefix(es) selecting the feature fields, as accepted by
            ``getFeaturesFields``.
        code: Name of the class field (passed as ``-cfield``).
        stat_file: Statistics file passed as ``-instat``.
        model_file: Model file passed as ``-model``.

    Returns:
        None.
    """
    # Platform dependent parameters
    if platform.system() == 'Linux':
        sh = False
    elif platform.system() == 'Windows':
        sh = True
    else:
        sys.exit("Platform not supported!")

    # Nothing to classify; also avoids an IndexError on shp_list[0] below.
    if not shp_list:
        return

    flds = getFeaturesFields(shp_list[0], feat_prefix)
    for shp in shp_list:
        # NOTE: unlike training(), NoData values are not repaired here
        # (the roughFix call is intentionally left disabled).
        cmd = ['otbcli_VectorClassifier', '-in', shp, '-instat', stat_file,
               '-model', model_file, '-cfield', code, '-feat'] + flds
        subprocess.call(cmd, shell=sh)

    return